author A. Unique TensorFlower <gardener@tensorflow.org> 2018-08-16 06:20:52 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-16 06:24:58 -0700
commit 938b9a40787028c58fb548fa6ada8c0dd8180f35 (patch)
tree b34f6644ec1be83f9b77f63d4858f5bbc3068ee0 /tensorflow/contrib/kfac
parent 26353f9b51091312e7097143aee9c2d05e2011fd (diff)
Automated rollback of commit 26353f9b51091312e7097143aee9c2d05e2011fd
PiperOrigin-RevId: 208973995
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--  tensorflow/contrib/kfac/BUILD  26
-rw-r--r--  tensorflow/contrib/kfac/README.md  93
-rw-r--r--  tensorflow/contrib/kfac/__init__.py  46
-rw-r--r--  tensorflow/contrib/kfac/examples/BUILD  80
-rw-r--r--  tensorflow/contrib/kfac/examples/convnet.py  667
-rw-r--r--  tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py  62
-rw-r--r--  tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py  48
-rw-r--r--  tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py  39
-rw-r--r--  tensorflow/contrib/kfac/examples/mlp.py  354
-rw-r--r--  tensorflow/contrib/kfac/examples/mlp_mnist_main.py  64
-rw-r--r--  tensorflow/contrib/kfac/examples/mnist.py  69
-rw-r--r--  tensorflow/contrib/kfac/examples/tests/BUILD  52
-rw-r--r--  tensorflow/contrib/kfac/examples/tests/convnet_test.py  166
-rw-r--r--  tensorflow/contrib/kfac/examples/tests/mlp_test.py  63
-rw-r--r--  tensorflow/contrib/kfac/examples/tests/mnist_test.py  72
-rw-r--r--  tensorflow/contrib/kfac/g3doc/autoencoder.png  bin 0 -> 54204 bytes
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/BUILD  160
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py  310
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py  1018
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py  955
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py  597
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py  190
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py  50
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py  219
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/utils_test.py  410
-rw-r--r--  tensorflow/contrib/kfac/python/ops/BUILD  263
-rw-r--r--  tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py  183
-rw-r--r--  tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py  30
-rw-r--r--  tensorflow/contrib/kfac/python/ops/estimator.py  516
-rw-r--r--  tensorflow/contrib/kfac/python/ops/estimator_lib.py  31
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_blocks.py  1752
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py  45
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_factors.py  1830
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py  38
-rw-r--r--  tensorflow/contrib/kfac/python/ops/layer_collection.py  1269
-rw-r--r--  tensorflow/contrib/kfac/python/ops/layer_collection_lib.py  46
-rw-r--r--  tensorflow/contrib/kfac/python/ops/linear_operator.py  95
-rw-r--r--  tensorflow/contrib/kfac/python/ops/loss_functions.py  754
-rw-r--r--  tensorflow/contrib/kfac/python/ops/loss_functions_lib.py  39
-rw-r--r--  tensorflow/contrib/kfac/python/ops/op_queue.py  69
-rw-r--r--  tensorflow/contrib/kfac/python/ops/op_queue_lib.py  30
-rw-r--r--  tensorflow/contrib/kfac/python/ops/optimizer.py  727
-rw-r--r--  tensorflow/contrib/kfac/python/ops/optimizer_lib.py  30
-rw-r--r--  tensorflow/contrib/kfac/python/ops/placement.py  114
-rw-r--r--  tensorflow/contrib/kfac/python/ops/utils.py  709
-rw-r--r--  tensorflow/contrib/kfac/python/ops/utils_lib.py  50
46 files changed, 14429 insertions, 1 deletions
diff --git a/tensorflow/contrib/kfac/BUILD b/tensorflow/contrib/kfac/BUILD
new file mode 100644
index 0000000000..b719046b37
--- /dev/null
+++ b/tensorflow/contrib/kfac/BUILD
@@ -0,0 +1,26 @@
+# Description:
+# Contains KfacOptimizer, an implementation of the K-FAC optimization
+# algorithm in TensorFlow.
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "kfac",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:curvature_matrix_vector_products_lib",
+ "//tensorflow/contrib/kfac/python/ops:fisher_blocks_lib",
+ "//tensorflow/contrib/kfac/python/ops:fisher_estimator_lib",
+ "//tensorflow/contrib/kfac/python/ops:fisher_factors_lib",
+ "//tensorflow/contrib/kfac/python/ops:kfac_optimizer_lib",
+ "//tensorflow/contrib/kfac/python/ops:layer_collection_lib",
+ "//tensorflow/contrib/kfac/python/ops:loss_functions_lib",
+ "//tensorflow/contrib/kfac/python/ops:op_queue_lib",
+ "//tensorflow/contrib/kfac/python/ops:utils_lib",
+ "//tensorflow/python:util",
+ ],
+)
diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md
index 42b91d0313..102626925d 100644
--- a/tensorflow/contrib/kfac/README.md
+++ b/tensorflow/contrib/kfac/README.md
@@ -1,3 +1,94 @@
# K-FAC: Kronecker-Factored Approximate Curvature
-## KFAC moved to third_party/tensorflow_kfac.
+# <font color="red", size=10><u>WARNING: </u></font>
+# ==third_party/tensorflow/contrib/kfac is deprecated. This will be==
+# ==removed on 15-07-2018. <!-- STY:begin_strip_and_replace -->Please import third_party/tensorflow_kfac.==
+# ==<!-- STY:end_strip_and_replace Please check https://github.com/tensorflow/kfac. -->==
+
+**K-FAC in TensorFlow** is a TensorFlow implementation of [K-FAC][kfac-paper],
+an approximate second-order optimization method. When applied to feedforward
+and convolutional neural networks, K-FAC can converge `>3.5x` faster and in
+`>14x` fewer iterations than SGD with Momentum.
+
+[kfac-paper]: https://arxiv.org/abs/1503.05671
+
+## What is K-FAC?
+
+K-FAC, short for "Kronecker-factored Approximate Curvature", is an approximation
+to the [Natural Gradient][natural_gradient] algorithm designed specifically for
+neural networks. It maintains a block-diagonal approximation to the [Fisher
+Information matrix][fisher_information], whose inverse preconditions the
+gradient.
+
+K-FAC can be used in place of SGD, Adam, and other `Optimizer` implementations.
+Experimentally, K-FAC converges `>3.5x` faster than well-tuned SGD.
+
+Unlike most optimizers, K-FAC exploits structure in the model itself (e.g. "What
+are the weights for layer i?"). As such, you must add some additional code while
+constructing your model to use K-FAC.
+
+[natural_gradient]: http://www.mitpressjournals.org/doi/abs/10.1162/089976698300017746
+[fisher_information]: https://en.wikipedia.org/wiki/Fisher_information#Matrix_form
+
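+For intuition only (this sketch is not part of the library), the following
+NumPy example shows how one Kronecker-factored block preconditions the
+gradient of a single fully-connected layer. The factors `A` and `G` are
+randomly generated stand-ins for the input-activation and output-gradient
+second-moment matrices that K-FAC estimates:
+
+```python
+import numpy as np
+
+rng = np.random.RandomState(0)
+in_dim, out_dim = 4, 3
+
+# Illustrative Kronecker factors, damped slightly so they are invertible.
+a = rng.randn(100, in_dim)   # stand-in for layer inputs (activations)
+g = rng.randn(100, out_dim)  # stand-in for backpropagated output gradients
+A = a.T.dot(a) / 100 + 1e-2 * np.eye(in_dim)
+G = g.T.dot(g) / 100 + 1e-2 * np.eye(out_dim)
+
+dW = rng.randn(in_dim, out_dim)  # gradient of the loss w.r.t. the weights
+
+# The layer's Fisher block is approximated as kron(G, A). Its inverse can be
+# applied without ever forming the full matrix, using
+#   kron(G, A)^{-1} vec(dW) == vec(A^{-1} dW G^{-1}).
+precond_dW = np.linalg.solve(A, np.linalg.solve(G, dW.T).T)
+```
+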
+## Why should I use K-FAC?
+
+K-FAC can take advantage of the curvature of the optimization problem, resulting
+in **faster training**. For an 8-layer Autoencoder, K-FAC converges to the same
+loss as SGD with Momentum in 3.8x fewer seconds and 14.7x fewer updates. See how
+training loss changes as a function of number of epochs, steps, and seconds:
+
+![autoencoder](g3doc/autoencoder.png)
+
+## Is K-FAC for me?
+
+If you have a feedforward or convolutional model for classification that is
+converging too slowly, K-FAC is for you. K-FAC can be used in your model if:
+
+* Your model defines a posterior distribution.
+* Your model uses only fully-connected or convolutional layers (residual
+ connections OK).
+* You are training on CPU or GPU.
+* You can modify model code to register layers with K-FAC.
+
+## How do I use K-FAC?
+
+Using K-FAC requires three steps:
+
+1. Registering layer inputs, weights, and pre-activations with a
+ `LayerCollection`.
+1. Minimizing the loss with a `KfacOptimizer`.
+1. Keeping K-FAC's preconditioner updated.
+
+```python
+# Build model.
+w = tf.get_variable("w", ...)
+b = tf.get_variable("b", ...)
+logits = tf.matmul(x, w) + b
+loss = tf.reduce_mean(
+ tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
+
+# Register layers.
+layer_collection = LayerCollection()
+layer_collection.register_fully_connected((w, b), x, logits)
+layer_collection.register_categorical_predictive_distribution(logits)
+
+# Construct training ops.
+optimizer = KfacOptimizer(..., layer_collection=layer_collection)
+train_op = optimizer.minimize(loss)
+
+# Minimize loss.
+with tf.Session() as sess:
+ ...
+ sess.run([train_op, optimizer.cov_update_op, optimizer.inv_update_op])
+```
+
+See [`examples/`](https://www.tensorflow.org/code/tensorflow/contrib/kfac/examples/) for runnable, end-to-end illustrations.
+
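+The bundled examples in this change use a thunk-based variant of the same
+workflow: instead of running `cov_update_op` and `inv_update_op` on every
+step, they build the update ops from thunks and refresh the inverses only
+periodically. A condensed sketch, adapted from `examples/convnet.py` below
+(the 10-step inversion interval is illustrative):
+
+```python
+# Construct training ops from covariance/inverse update thunks.
+global_step = tf.train.get_or_create_global_step()
+optimizer = KfacOptimizer(..., layer_collection=layer_collection)
+(cov_update_thunks,
+ inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+
+cov_update_op = tf.group(*[thunk() for thunk in cov_update_thunks])
+with tf.control_dependencies([cov_update_op]):
+  # Inverses are expensive; recompute them only every few steps.
+  inverse_op = tf.cond(
+      tf.equal(tf.mod(global_step, 10), 0),
+      lambda: tf.group(*[thunk() for thunk in inv_update_thunks]), tf.no_op)
+with tf.control_dependencies([inverse_op]):
+  train_op = optimizer.minimize(loss, global_step=global_step)
+```
+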
+## Authors
+
+- Alok Aggarwal
+- Daniel Duckworth
+- James Martens
+- Matthew Johnson
+- Olga Wichrowska
+- Roger Grosse
diff --git a/tensorflow/contrib/kfac/__init__.py b/tensorflow/contrib/kfac/__init__.py
new file mode 100644
index 0000000000..1ea354e6cd
--- /dev/null
+++ b/tensorflow/contrib/kfac/__init__.py
@@ -0,0 +1,46 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Kronecker-factored Approximate Curvature Optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long
+from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products_lib as curvature_matrix_vector_products
+from tensorflow.contrib.kfac.python.ops import estimator_lib as estimator
+from tensorflow.contrib.kfac.python.ops import fisher_blocks_lib as fisher_blocks
+from tensorflow.contrib.kfac.python.ops import fisher_factors_lib as fisher_factors
+from tensorflow.contrib.kfac.python.ops import layer_collection_lib as layer_collection
+from tensorflow.contrib.kfac.python.ops import loss_functions_lib as loss_functions
+from tensorflow.contrib.kfac.python.ops import op_queue_lib as op_queue
+from tensorflow.contrib.kfac.python.ops import optimizer_lib as optimizer
+from tensorflow.contrib.kfac.python.ops import utils_lib as utils
+from tensorflow.python.util.all_util import remove_undocumented
+# pylint: enable=unused-import,line-too-long
+
+_allowed_symbols = [
+ "curvature_matrix_vector_products",
+ "estimator",
+ "fisher_blocks",
+ "fisher_factors",
+ "layer_collection",
+ "loss_functions",
+ "op_queue",
+ "optimizer",
+ "utils",
+]
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD
new file mode 100644
index 0000000000..8186fa1c62
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/BUILD
@@ -0,0 +1,80 @@
+package(default_visibility = [
+ "//learning/brain/contrib/kfac/examples:__subpackages__",
+ "//tensorflow/contrib/kfac/examples:__subpackages__",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_binary(
+ name = "mlp_mnist_main",
+ srcs = ["mlp_mnist_main.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":mlp",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "mlp",
+ srcs = ["mlp.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":mnist",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
+ name = "convnet_mnist_single_main",
+ srcs = ["convnet_mnist_single_main.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":convnet",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
+ name = "convnet_mnist_multi_tower_main",
+ srcs = ["convnet_mnist_multi_tower_main.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":convnet",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
+ name = "convnet_mnist_distributed_main",
+ srcs = ["convnet_mnist_distributed_main.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":convnet",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "convnet",
+ srcs = ["convnet.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":mlp",
+ ":mnist",
+ "//tensorflow:tensorflow_py",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "mnist",
+ srcs = ["mnist.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py
new file mode 100644
index 0000000000..44e01e1aeb
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/convnet.py
@@ -0,0 +1,667 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Train a ConvNet on MNIST using K-FAC.
+
+This library fits a 5-layer ConvNet on MNIST using K-FAC. The model has the
+following structure,
+
+- Conv Layer: 5x5 kernel, 16 output channels.
+- Max Pool: 3x3 kernel, stride 2.
+- Conv Layer: 5x5 kernel, 16 output channels.
+- Max Pool: 3x3 kernel, stride 2.
+- Linear: 10 output dims.
+
+After 3k~6k steps, this should reach perfect accuracy on the training set.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.kfac.examples import mlp
+from tensorflow.contrib.kfac.examples import mnist
+from tensorflow.contrib.kfac.python.ops import optimizer as opt
+
+
+lc = tf.contrib.kfac.layer_collection
+oq = tf.contrib.kfac.op_queue
+opt = tf.contrib.kfac.optimizer
+
+__all__ = [
+ "conv_layer",
+ "max_pool_layer",
+ "linear_layer",
+ "build_model",
+ "minimize_loss_single_machine",
+ "distributed_grads_only_and_ops_chief_worker",
+ "distributed_grads_and_ops_dedicated_workers",
+ "train_mnist_single_machine",
+ "train_mnist_distributed_sync_replicas",
+ "train_mnist_multitower"
+]
+
+
+# Inverse update ops will be run every _INVERT_EVERY iterations.
+_INVERT_EVERY = 10
+
+
+def conv_layer(layer_id, inputs, kernel_size, out_channels):
+ """Builds a convolutional layer with ReLU non-linearity.
+
+ Args:
+ layer_id: int. Integer ID for this layer's variables.
+ inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
+ corresponds to a single example.
+ kernel_size: int. Width and height of the convolution kernel. The kernel is
+ assumed to be square.
+ out_channels: int. Number of output features per pixel.
+
+ Returns:
+ preactivations: Tensor of shape [num_examples, width, height, out_channels].
+ Values of the layer immediately before the activation function.
+ activations: Tensor of shape [num_examples, width, height, out_channels].
+ Values of the layer immediately after the activation function.
+ params: Tuple of (kernel, bias), parameters for this layer.
+ """
+ # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
+ layer = tf.layers.Conv2D(
+ out_channels,
+ kernel_size=[kernel_size, kernel_size],
+ kernel_initializer=tf.random_normal_initializer(stddev=0.01),
+ padding="SAME",
+ name="conv_%d" % layer_id)
+ preactivations = layer(inputs)
+ activations = tf.nn.relu(preactivations)
+
+ # layer.weights is a list. This converts it to a (hashable) tuple.
+ return preactivations, activations, (layer.kernel, layer.bias)
+
+
+def max_pool_layer(layer_id, inputs, kernel_size, stride):
+ """Build a max-pooling layer.
+
+ Args:
+ layer_id: int. Integer ID for this layer's variables.
+ inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
+ corresponds to a single example.
+ kernel_size: int. Width and height to pool over per input channel. The
+ kernel is assumed to be square.
+ stride: int. Step size between pooling operations.
+
+ Returns:
+ Tensor of shape [num_examples, width/stride, height/stride, out_channels].
+ Result of applying max pooling to 'inputs'.
+ """
+ # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
+ with tf.variable_scope("pool_%d" % layer_id):
+ return tf.nn.max_pool(
+ inputs, [1, kernel_size, kernel_size, 1], [1, stride, stride, 1],
+ padding="SAME",
+ name="pool")
+
+
+def linear_layer(layer_id, inputs, output_size):
+ """Builds the final linear layer for an MNIST classification problem.
+
+ Args:
+ layer_id: int. Integer ID for this layer's variables.
+ inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
+ corresponds to a single example.
+ output_size: int. Number of output dims per example.
+
+ Returns:
+ activations: Tensor of shape [num_examples, output_size]. Values of the
+ layer immediately after the activation function.
+ params: Tuple of (weights, bias), parameters for this layer.
+ """
+ # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
+ pre, _, params = mlp.fc_layer(layer_id, inputs, output_size)
+ return pre, params
+
+
+def build_model(examples, labels, num_labels, layer_collection):
+ """Builds a ConvNet classification model.
+
+ Args:
+ examples: Tensor of shape [num_examples, num_features]. Represents inputs of
+ model.
+ labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted
+ by softmax for each example.
+ num_labels: int. Number of distinct values 'labels' can take on.
+ layer_collection: LayerCollection instance. Layers will be registered here.
+
+ Returns:
+ loss: 0-D Tensor representing loss to be minimized.
+ accuracy: 0-D Tensor representing model's accuracy.
+ """
+ # Build a ConvNet. For each layer with parameters, we'll keep track of the
+ # preactivations, activations, weights, and bias.
+ tf.logging.info("Building model.")
+ pre0, act0, params0 = conv_layer(
+ layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
+ act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
+ pre2, act2, params2 = conv_layer(
+ layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
+ act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
+ flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
+ logits, params4 = linear_layer(
+ layer_id=4, inputs=flat_act3, output_size=num_labels)
+ loss = tf.reduce_mean(
+ tf.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits))
+ accuracy = tf.reduce_mean(
+ tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
+
+ with tf.device("/cpu:0"):
+ tf.summary.scalar("loss", loss)
+ tf.summary.scalar("accuracy", accuracy)
+
+ # Register parameters. K-FAC needs to know about the inputs, outputs, and
+ # parameters of each conv/fully connected layer and the logits powering the
+ # posterior probability over classes.
+ tf.logging.info("Building LayerCollection.")
+ layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
+ pre0)
+ layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
+ layer_collection.register_fully_connected(params4, flat_act3, logits)
+ layer_collection.register_categorical_predictive_distribution(
+ logits, name="logits")
+
+ return loss, accuracy
+
+
+def minimize_loss_single_machine(loss,
+ accuracy,
+ layer_collection,
+ device="/gpu:0",
+ session_config=None):
+ """Minimize loss with K-FAC on a single machine.
+
+ A single Session is responsible for running all of K-FAC's ops. The covariance
+ and inverse update ops are placed on `device`. All model variables are on CPU.
+
+ Args:
+ loss: 0-D Tensor. Loss to be minimized.
+ accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
+ layer_collection: LayerCollection instance describing model architecture.
+ Used by K-FAC to construct preconditioner.
+ device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse
+ update ops are run on this device.
+ session_config: None or tf.ConfigProto. Configuration for tf.Session().
+
+ Returns:
+ final value for 'accuracy'.
+ """
+ # Train with K-FAC.
+ g_step = tf.train.get_or_create_global_step()
+ optimizer = opt.KfacOptimizer(
+ learning_rate=0.0001,
+ cov_ema_decay=0.95,
+ damping=0.001,
+ layer_collection=layer_collection,
+ placement_strategy="round_robin",
+ cov_devices=[device],
+ inv_devices=[device],
+ momentum=0.9)
+ (cov_update_thunks,
+ inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+
+ def make_update_op(update_thunks):
+ update_ops = [thunk() for thunk in update_thunks]
+ return tf.group(*update_ops)
+
+ cov_update_op = make_update_op(cov_update_thunks)
+ with tf.control_dependencies([cov_update_op]):
+ inverse_op = tf.cond(
+ tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
+ lambda: make_update_op(inv_update_thunks), tf.no_op)
+ with tf.control_dependencies([inverse_op]):
+ with tf.device(device):
+ train_op = optimizer.minimize(loss, global_step=g_step)
+
+ tf.logging.info("Starting training.")
+ with tf.train.MonitoredTrainingSession(config=session_config) as sess:
+ while not sess.should_stop():
+ global_step_, loss_, accuracy_, _ = sess.run(
+ [g_step, loss, accuracy, train_op])
+
+ if global_step_ % _INVERT_EVERY == 0:
+ tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
+ global_step_, loss_, accuracy_)
+
+ return accuracy_
+
+
+def _is_gradient_task(task_id, num_tasks):
+ """Returns True if this task should update the weights."""
+ if num_tasks < 3:
+ return True
+ return 0 <= task_id < 0.6 * num_tasks
+
+
+def _is_cov_update_task(task_id, num_tasks):
+ """Returns True if this task should update K-FAC's covariance matrices."""
+ if num_tasks < 3:
+ return False
+ return 0.6 * num_tasks <= task_id < num_tasks - 1
+
+
+def _is_inv_update_task(task_id, num_tasks):
+ """Returns True if this task should update K-FAC's preconditioner."""
+ if num_tasks < 3:
+ return False
+ return task_id == num_tasks - 1
+
+
+def _num_gradient_tasks(num_tasks):
+ """Number of tasks that will update weights."""
+ if num_tasks < 3:
+ return num_tasks
+ return int(np.ceil(0.6 * num_tasks))
+
+
+def _make_distributed_train_op(
+ task_id,
+ num_worker_tasks,
+ num_ps_tasks,
+ layer_collection
+):
+ """Creates optimizer and distributed training op.
+
+ Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes
+ the train op.
+
+ Args:
+ task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+ num_worker_tasks: int. Number of workers in this distributed training setup.
+ num_ps_tasks: int. Number of parameter servers holding variables. If 0,
+ parameter servers are not used.
+ layer_collection: LayerCollection instance describing model architecture.
+ Used by K-FAC to construct preconditioner.
+
+ Returns:
+ sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC
+ optimizer.
+ optimizer: Instance of `opt.KfacOptimizer`.
+ global_step: `tensor`, Global step.
+ """
+ tf.logging.info("Task id : %d", task_id)
+ with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
+ global_step = tf.train.get_or_create_global_step()
+ optimizer = opt.KfacOptimizer(
+ learning_rate=0.0001,
+ cov_ema_decay=0.95,
+ damping=0.001,
+ layer_collection=layer_collection,
+ momentum=0.9)
+ sync_optimizer = tf.train.SyncReplicasOptimizer(
+ opt=optimizer,
+ replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks),
+ total_num_replicas=num_worker_tasks)
+ return sync_optimizer, optimizer, global_step
+
+
+def distributed_grads_only_and_ops_chief_worker(
+ task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
+ loss, accuracy, layer_collection, invert_every=10):
+ """Minimize loss with a synchronous implementation of K-FAC.
+
+ All workers perform gradient computation. Chief worker applies gradient after
+ averaging the gradients obtained from all the workers. All workers block
+ execution until the update is applied. Chief worker runs covariance and
+ inverse update ops. Covariance and inverse matrices are placed on parameter
+ servers in a round robin manner. For further details on synchronous
+ distributed optimization check `tf.train.SyncReplicasOptimizer`.
+
+ Args:
+ task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+ is_chief: `boolean`, `True` if the worker is chief worker.
+ num_worker_tasks: int. Number of workers in this distributed training setup.
+ num_ps_tasks: int. Number of parameter servers holding variables. If 0,
+ parameter servers are not used.
+ master: string. IP and port of TensorFlow runtime process. Set to empty
+ string to run locally.
+ checkpoint_dir: string or None. Path to store checkpoints under.
+ loss: 0-D Tensor. Loss to be minimized.
+ accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
+ layer_collection: LayerCollection instance describing model architecture.
+ Used by K-FAC to construct preconditioner.
+ invert_every: `int`, Number of steps between inverse updates.
+
+ Returns:
+ final value for 'accuracy'.
+
+ Raises:
+ ValueError: if task_id >= num_worker_tasks.
+ """
+
+ sync_optimizer, optimizer, global_step = _make_distributed_train_op(
+ task_id, num_worker_tasks, num_ps_tasks, layer_collection)
+ (cov_update_thunks,
+ inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+
+ tf.logging.info("Starting training.")
+ hooks = [sync_optimizer.make_session_run_hook(is_chief)]
+
+ def make_update_op(update_thunks):
+ update_ops = [thunk() for thunk in update_thunks]
+ return tf.group(*update_ops)
+
+ if is_chief:
+ cov_update_op = make_update_op(cov_update_thunks)
+ with tf.control_dependencies([cov_update_op]):
+ inverse_op = tf.cond(
+ tf.equal(tf.mod(global_step, invert_every), 0),
+ lambda: make_update_op(inv_update_thunks),
+ tf.no_op)
+ with tf.control_dependencies([inverse_op]):
+ train_op = sync_optimizer.minimize(loss, global_step=global_step)
+ else:
+ train_op = sync_optimizer.minimize(loss, global_step=global_step)
+
+ with tf.train.MonitoredTrainingSession(
+ master=master,
+ is_chief=is_chief,
+ checkpoint_dir=checkpoint_dir,
+ hooks=hooks,
+ stop_grace_period_secs=0) as sess:
+ while not sess.should_stop():
+ global_step_, loss_, accuracy_, _ = sess.run(
+ [global_step, loss, accuracy, train_op])
+ tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
+ loss_, accuracy_)
+ return accuracy_
+
+
+def distributed_grads_and_ops_dedicated_workers(
+ task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
+ loss, accuracy, layer_collection):
+ """Minimize loss with a synchronous implementation of K-FAC.
+
+ Different workers are responsible for different parts of K-FAC's Ops. The
+ first 60% of tasks compute gradients; the next 20% accumulate covariance
+ statistics; the last 20% invert the matrices used to precondition gradients.
+ The chief worker applies the gradient.
+
+ Args:
+ task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+ is_chief: `boolean`, `True` if the worker is chief worker.
+ num_worker_tasks: int. Number of workers in this distributed training setup.
+ num_ps_tasks: int. Number of parameter servers holding variables. If 0,
+ parameter servers are not used.
+ master: string. IP and port of TensorFlow runtime process. Set to empty
+ string to run locally.
+ checkpoint_dir: string or None. Path to store checkpoints under.
+ loss: 0-D Tensor. Loss to be minimized.
+ accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
+ layer_collection: LayerCollection instance describing model architecture.
+ Used by K-FAC to construct preconditioner.
+
+ Returns:
+ final value for 'accuracy'.
+
+ Raises:
+ ValueError: if task_id >= num_worker_tasks.
+ """
+ sync_optimizer, optimizer, global_step = _make_distributed_train_op(
+ task_id, num_worker_tasks, num_ps_tasks, layer_collection)
+ _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars()
+ train_op = sync_optimizer.minimize(loss, global_step=global_step)
+ inv_update_queue = oq.OpQueue(inv_update_ops)
+
+ tf.logging.info("Starting training.")
+ is_chief = (task_id == 0)
+ hooks = [sync_optimizer.make_session_run_hook(is_chief)]
+ with tf.train.MonitoredTrainingSession(
+ master=master,
+ is_chief=is_chief,
+ checkpoint_dir=checkpoint_dir,
+ hooks=hooks,
+ stop_grace_period_secs=0) as sess:
+ while not sess.should_stop():
+ # Choose which op this task is responsible for running.
+ if _is_gradient_task(task_id, num_worker_tasks):
+ learning_op = train_op
+ elif _is_cov_update_task(task_id, num_worker_tasks):
+ learning_op = cov_update_op
+ elif _is_inv_update_task(task_id, num_worker_tasks):
+ # TODO(duckworthd): Running this op before cov_update_op has been run a
+ # few times can result in "InvalidArgumentError: Cholesky decomposition
+ # was not successful." Delay running this op until cov_update_op has
+ # been run a few times.
+ learning_op = inv_update_queue.next_op(sess)
+ else:
+ raise ValueError("Which op should task %d do?" % task_id)
+
+ global_step_, loss_, accuracy_, _ = sess.run(
+ [global_step, loss, accuracy, learning_op])
+ tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
+ loss_, accuracy_)
+
+ return accuracy_
+
+
+def train_mnist_single_machine(data_dir,
+ num_epochs,
+ use_fake_data=False,
+ device="/gpu:0"):
+ """Train a ConvNet on MNIST.
+
+ Args:
+ data_dir: string. Directory to read MNIST examples from.
+ num_epochs: int. Number of passes to make over the training set.
+ use_fake_data: bool. If True, generate a synthetic dataset.
+ device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse
+ update ops are run on this device.
+
+ Returns:
+ accuracy of model on the final minibatch of training data.
+ """
+ # Load a dataset.
+ tf.logging.info("Loading MNIST into memory.")
+ examples, labels = mnist.load_mnist(
+ data_dir,
+ num_epochs=num_epochs,
+ batch_size=128,
+ use_fake_data=use_fake_data,
+ flatten_images=False)
+
+ # Build a ConvNet.
+ layer_collection = lc.LayerCollection()
+ loss, accuracy = build_model(
+ examples, labels, num_labels=10, layer_collection=layer_collection)
+
+ # Fit model.
+ return minimize_loss_single_machine(
+ loss, accuracy, layer_collection, device=device)
+
+
+def train_mnist_multitower(data_dir, num_epochs, num_towers,
+ use_fake_data=True, devices=None):
+ """Train a ConvNet on MNIST.
+
+ Training data is split equally among the towers. Each tower computes loss on
+ its own batch of data and the loss is aggregated on the CPU. The model
+ variables are placed on first tower. The covariance and inverse update ops
+ and variables are placed on GPUs in a round robin manner.
+
+ Args:
+ data_dir: string. Directory to read MNIST examples from.
+ num_epochs: int. Number of passes to make over the training set.
+ num_towers: int. Number of towers (devices) to split the minibatch across.
+ use_fake_data: bool. If True, generate a synthetic dataset.
+ devices: list of device strings (CPU or GPU), or None. The towers and the
+ covariance and inverse update ops are placed on these devices. Defaults to
+ CPU devices if None.
+
+ Returns:
+ accuracy of model on the final minibatch of training data.
+ """
+ if devices:
+ device_count = {"GPU": num_towers}
+ else:
+ device_count = {"CPU": num_towers}
+
+ devices = devices or [
+ "/cpu:{}".format(tower_id) for tower_id in range(num_towers)
+ ]
+ # Load a dataset.
+ tf.logging.info("Loading MNIST into memory.")
+ tower_batch_size = 128
+ batch_size = tower_batch_size * num_towers
+ tf.logging.info(
+ ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
+ "tower batch size.") % (batch_size, num_towers, tower_batch_size))
+ examples, labels = mnist.load_mnist(
+ data_dir,
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ use_fake_data=use_fake_data,
+ flatten_images=False)
+
+ # Split minibatch across towers.
+ examples = tf.split(examples, num_towers)
+ labels = tf.split(labels, num_towers)
+
+ # Build a ConvNet. Each tower's layers will be added to the LayerCollection.
+ layer_collection = lc.LayerCollection()
+ tower_results = []
+ for tower_id in range(num_towers):
+ with tf.device(devices[tower_id]):
+ with tf.name_scope("tower%d" % tower_id):
+ with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
+ tf.logging.info("Building tower %d." % tower_id)
+ tower_results.append(
+ build_model(examples[tower_id], labels[tower_id], 10,
+ layer_collection))
+ losses, accuracies = zip(*tower_results)
+
+ # Average across towers.
+ loss = tf.reduce_mean(losses)
+ accuracy = tf.reduce_mean(accuracies)
+
+ # Fit model.
+
+ session_config = tf.ConfigProto(
+ allow_soft_placement=False,
+ device_count=device_count,
+ )
+
+ g_step = tf.train.get_or_create_global_step()
+ optimizer = opt.KfacOptimizer(
+ learning_rate=0.0001,
+ cov_ema_decay=0.95,
+ damping=0.001,
+ layer_collection=layer_collection,
+ placement_strategy="round_robin",
+ cov_devices=devices,
+ inv_devices=devices,
+ momentum=0.9)
+ (cov_update_thunks,
+ inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+
+ def make_update_op(update_thunks):
+ update_ops = [thunk() for thunk in update_thunks]
+ return tf.group(*update_ops)
+
+ cov_update_op = make_update_op(cov_update_thunks)
+ with tf.control_dependencies([cov_update_op]):
+ inverse_op = tf.cond(
+ tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
+ lambda: make_update_op(inv_update_thunks), tf.no_op)
+ with tf.control_dependencies([inverse_op]):
+ train_op = optimizer.minimize(loss, global_step=g_step)
+
+ tf.logging.info("Starting training.")
+ with tf.train.MonitoredTrainingSession(config=session_config) as sess:
+ while not sess.should_stop():
+ global_step_, loss_, accuracy_, _ = sess.run(
+ [g_step, loss, accuracy, train_op])
+
+ if global_step_ % _INVERT_EVERY == 0:
+ tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
+ global_step_, loss_, accuracy_)
+
+ return accuracy_
+
+
+def train_mnist_distributed_sync_replicas(task_id,
+ is_chief,
+ num_worker_tasks,
+ num_ps_tasks,
+ master,
+ data_dir,
+ num_epochs,
+ op_strategy,
+ use_fake_data=False):
+ """Train a ConvNet on MNIST using Sync replicas optimizer.
+
+ Args:
+ task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
+ is_chief: `boolean`, `True` if the worker is chief worker.
+ num_worker_tasks: int. Number of workers in this distributed training setup.
+ num_ps_tasks: int. Number of parameter servers holding variables.
+ master: string. IP and port of TensorFlow runtime process.
+ data_dir: string. Directory to read MNIST examples from.
+ num_epochs: int. Number of passes to make over the training set.
+ op_strategy: `string`, Strategy to run the covariance and inverse
+ ops. If op_strategy == `chief_worker` then covariance and inverse
+ update ops are run on the chief worker, otherwise they are run on
+ dedicated workers.
+ use_fake_data: bool. If True, generate a synthetic dataset.
+
+ Returns:
+ accuracy of model on the final minibatch of training data.
+
+ Raises:
+ ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"].
+ """
+ # Load a dataset.
+ tf.logging.info("Loading MNIST into memory.")
+ examples, labels = mnist.load_mnist(
+ data_dir,
+ num_epochs=num_epochs,
+ batch_size=128,
+ use_fake_data=use_fake_data,
+ flatten_images=False)
+
+ # Build a ConvNet.
+ layer_collection = lc.LayerCollection()
+ with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
+ loss, accuracy = build_model(
+ examples, labels, num_labels=10, layer_collection=layer_collection)
+
+ # Fit model.
+ checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac")
+ if op_strategy == "chief_worker":
+ return distributed_grads_only_and_ops_chief_worker(
+ task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
+ checkpoint_dir, loss, accuracy, layer_collection)
+ elif op_strategy == "dedicated_workers":
+ return distributed_grads_and_ops_dedicated_workers(
+ task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
+ checkpoint_dir, loss, accuracy, layer_collection)
+ else:
+ raise ValueError("Only supported op strategies are : {}, {}".format(
+ "chief_worker", "dedicated_workers"))
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py
new file mode 100644
index 0000000000..b4c2d4a9e9
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py
@@ -0,0 +1,62 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Train a ConvNet on MNIST using K-FAC.
+
+Distributed training with sync replicas optimizer. See
+`convnet.train_mnist_distributed_sync_replicas` for details.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from absl import flags
+import tensorflow as tf
+
+from tensorflow.contrib.kfac.examples import convnet
+
+FLAGS = flags.FLAGS
+flags.DEFINE_integer("task", -1, "Task identifier")
+flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
+flags.DEFINE_string(
+ "cov_inv_op_strategy", "chief_worker",
+ "In dist training mode run the cov, inv ops on chief or dedicated workers."
+)
+flags.DEFINE_string("master", "local", "Session master.")
+flags.DEFINE_integer("ps_tasks", 2,
+ "Number of tasks in the parameter server job.")
+flags.DEFINE_integer("replicas_to_aggregate", 5,
+ "Number of replicas to aggregate.")
+flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.")
+flags.DEFINE_integer("num_epochs", None, "Number of epochs.")
+
+
+def _is_chief():
+ """Determines whether a job is the chief worker."""
+ if "chief_worker" in FLAGS.brain_jobs:
+ return FLAGS.brain_job_name == "chief_worker"
+ else:
+ return FLAGS.task == 0
+
+
+def main(unused_argv):
+ _ = unused_argv
+ convnet.train_mnist_distributed_sync_replicas(
+ FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks,
+ FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy)
+
+if __name__ == "__main__":
+ tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py
new file mode 100644
index 0000000000..4249bf8a8d
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py
@@ -0,0 +1,48 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Train a ConvNet on MNIST using K-FAC.
+
+Multi tower training mode. See `convnet.train_mnist_multitower` for details.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from absl import flags
+import tensorflow as tf
+
+from tensorflow.contrib.kfac.examples import convnet
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir")
+flags.DEFINE_integer("num_towers", 2,
+ "Number of towers for multi tower training.")
+
+
+def main(unused_argv):
+ _ = unused_argv
+ assert FLAGS.num_towers > 1
+ devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)]
+ convnet.train_mnist_multitower(
+ FLAGS.data_dir,
+ num_epochs=200,
+ num_towers=FLAGS.num_towers,
+ devices=devices)
+
+
+if __name__ == "__main__":
+ tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
new file mode 100644
index 0000000000..2c1f099360
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
@@ -0,0 +1,39 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Train a ConvNet on MNIST using K-FAC.
+
+Train on single machine. See `convnet.train_mnist_single_machine` for details.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from absl import flags
+import tensorflow as tf
+
+from tensorflow.contrib.kfac.examples import convnet
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
+
+
+def main(unused_argv):
+ convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
+
+
+if __name__ == "__main__":
+ tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py
new file mode 100644
index 0000000000..ea2b252a05
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/mlp.py
@@ -0,0 +1,354 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Train an MLP on MNIST using K-FAC.
+
+This library fits a 3-layer, tanh-activated MLP on MNIST using K-FAC. After
+~25k steps, this should reach perfect accuracy on the training set.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.kfac.examples import mnist
+
+lc = tf.contrib.kfac.layer_collection
+opt = tf.contrib.kfac.optimizer
+
+__all__ = [
+ "fc_layer",
+ "train_mnist",
+ "train_mnist_multitower",
+]
+
+
+def fc_layer(layer_id, inputs, output_size):
+ """Builds a fully connected layer.
+
+ Args:
+ layer_id: int. Integer ID for this layer's variables.
+ inputs: Tensor of shape [num_examples, input_size]. Each row corresponds
+ to a single example.
+ output_size: int. Number of output dimensions after fully connected layer.
+
+ Returns:
+ preactivations: Tensor of shape [num_examples, output_size]. Values of the
+ layer immediately before the activation function.
+ activations: Tensor of shape [num_examples, output_size]. Values of the
+ layer immediately after the activation function.
+ params: Tuple of (weights, bias), parameters for this layer.
+ """
+ # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
+ layer = tf.layers.Dense(
+ output_size,
+ kernel_initializer=tf.random_normal_initializer(),
+ name="fc_%d" % layer_id)
+ preactivations = layer(inputs)
+ activations = tf.nn.tanh(preactivations)
+
+ # layer.weights is a list. This converts it to a (hashable) tuple.
+ return preactivations, activations, (layer.kernel, layer.bias)
+
+
+def build_model(examples, labels, num_labels, layer_collection):
+ """Builds an MLP classification model.
+
+ Args:
+ examples: Tensor of shape [num_examples, num_features]. Represents inputs of
+ model.
+ labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted
+ by softmax for each example.
+ num_labels: int. Number of distinct values 'labels' can take on.
+ layer_collection: LayerCollection instance describing model architecture.
+
+ Returns:
+ loss: 0-D Tensor representing loss to be minimized.
+ accuracy: 0-D Tensor representing model's accuracy.
+ """
+ # Build an MLP. For each layer, we'll keep track of the preactivations,
+ # activations, weights, and bias.
+ pre0, act0, params0 = fc_layer(layer_id=0, inputs=examples, output_size=128)
+ pre1, act1, params1 = fc_layer(layer_id=1, inputs=act0, output_size=64)
+ pre2, act2, params2 = fc_layer(layer_id=2, inputs=act1, output_size=32)
+ logits, _, params3 = fc_layer(layer_id=3, inputs=act2, output_size=num_labels)
+ loss = tf.reduce_mean(
+ tf.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits))
+ accuracy = tf.reduce_mean(
+ tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
+
+ # Register parameters. K-FAC needs to know about the inputs, outputs, and
+ # parameters of each layer and the logits powering the posterior probability
+ # over classes.
+ tf.logging.info("Building LayerCollection.")
+ layer_collection.register_fully_connected(params0, examples, pre0)
+ layer_collection.register_fully_connected(params1, act0, pre1)
+ layer_collection.register_fully_connected(params2, act1, pre2)
+ layer_collection.register_fully_connected(params3, act2, logits)
+ layer_collection.register_categorical_predictive_distribution(
+ logits, name="logits")
+
+ return loss, accuracy
+
+
+def minimize(loss, accuracy, layer_collection, num_towers, session_config=None):
+ """Minimize 'loss' with KfacOptimizer.
+
+ Args:
+ loss: 0-D Tensor. Loss to be minimized.
+ accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
+ layer_collection: LayerCollection instance. Describes layers in model.
+ num_towers: int. Number of CPUs to split minibatch across.
+ session_config: tf.ConfigProto. Configuration for tf.Session().
+
+ Returns:
+ accuracy of classifier on final minibatch.
+ """
+ devices = tuple("/cpu:%d" % tower_id for tower_id in range(num_towers))
+
+ # Train with K-FAC. We'll use a decreasing learning rate that's cut in 1/2
+ # every 10k iterations.
+ tf.logging.info("Building KFAC Optimizer.")
+ global_step = tf.train.get_or_create_global_step()
+ optimizer = opt.KfacOptimizer(
+ learning_rate=tf.train.exponential_decay(
+ 0.00002, global_step, 10000, 0.5, staircase=True),
+ cov_ema_decay=0.95,
+ damping=0.0005,
+ layer_collection=layer_collection,
+ momentum=0.99,
+ placement_strategy="round_robin",
+ cov_devices=devices,
+ inv_devices=devices)
+
+ (cov_update_thunks,
+ inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+
+ def make_update_op(update_thunks):
+ update_ops = [thunk() for thunk in update_thunks]
+ return tf.group(*update_ops)
+
+ # TODO(b/78537047): change (some) examples to use PeriodicInvCovUpdateKfacOpt
+ # once that gets moved over? Could still leave more advanced examples as they
+ # are (e.g. train_mnist_estimator in this file)
+
+ cov_update_op = make_update_op(cov_update_thunks)
+ with tf.control_dependencies([cov_update_op]):
+ # We update the inverses only every 100 iterations.
+ inverse_op = tf.cond(
+ tf.equal(tf.mod(global_step, 100), 0),
+ lambda: make_update_op(inv_update_thunks), tf.no_op)
+ with tf.control_dependencies([inverse_op]):
+ train_op = optimizer.minimize(loss, global_step=global_step)
+
+ tf.logging.info("Starting training.")
+ with tf.train.MonitoredTrainingSession(config=session_config) as sess:
+ while not sess.should_stop():
+ global_step_, loss_, accuracy_, _ = sess.run(
+ [global_step, loss, accuracy, train_op])
+
+ if global_step_ % 100 == 0:
+ tf.logging.info("global_step: %d | loss: %f | accuracy: %f",
+ global_step_, loss_, accuracy_)
+
+ return accuracy_
+
+
+def train_mnist(data_dir, num_epochs, use_fake_data=False):
+ """Train an MLP on MNIST.
+
+ Args:
+ data_dir: string. Directory to read MNIST examples from.
+ num_epochs: int. Number of passes to make over the training set.
+ use_fake_data: bool. If True, generate a synthetic dataset.
+
+ Returns:
+ accuracy of model on the final minibatch of training data.
+ """
+ # Load a dataset.
+ tf.logging.info("Loading MNIST into memory.")
+ examples, labels = mnist.load_mnist(
+ data_dir,
+ num_epochs=num_epochs,
+ batch_size=64,
+ flatten_images=True,
+ use_fake_data=use_fake_data)
+
+ # Build an MLP. The model's layers will be added to the LayerCollection.
+ tf.logging.info("Building model.")
+ layer_collection = lc.LayerCollection()
+ loss, accuracy = build_model(examples, labels, 10, layer_collection)
+
+ # Fit model.
+ return minimize(loss, accuracy, layer_collection, 1)
+
+
+def train_mnist_multitower(data_dir,
+ num_epochs,
+ num_towers,
+ use_fake_data=False):
+ """Train an MLP on MNIST, splitting the minibatch across multiple towers.
+
+ Args:
+ data_dir: string. Directory to read MNIST examples from.
+ num_epochs: int. Number of passes to make over the training set.
+ num_towers: int. Number of CPUs to split minibatch across.
+ use_fake_data: bool. If True, generate a synthetic dataset.
+
+ Returns:
+ accuracy of model on the final minibatch of training data.
+ """
+ # Load a dataset.
+ tower_batch_size = 64
+ batch_size = tower_batch_size * num_towers
+ tf.logging.info(
+ ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
+ "tower batch size.") % (batch_size, num_towers, tower_batch_size))
+ examples, labels = mnist.load_mnist(
+ data_dir,
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ flatten_images=True,
+ use_fake_data=use_fake_data)
+
+ # Split minibatch across towers.
+ examples = tf.split(examples, num_towers)
+ labels = tf.split(labels, num_towers)
+
+ # Build an MLP. Each tower's layers will be added to the LayerCollection.
+ layer_collection = lc.LayerCollection()
+ tower_results = []
+ for tower_id in range(num_towers):
+ with tf.device("/cpu:%d" % tower_id):
+ with tf.name_scope("tower%d" % tower_id):
+ with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
+ tf.logging.info("Building tower %d." % tower_id)
+ tower_results.append(
+ build_model(examples[tower_id], labels[tower_id], 10,
+ layer_collection))
+ losses, accuracies = zip(*tower_results)
+
+ # Average across towers.
+ loss = tf.reduce_mean(losses)
+ accuracy = tf.reduce_mean(accuracies)
+
+ # Fit model.
+ session_config = tf.ConfigProto(
+ allow_soft_placement=False, device_count={
+ "CPU": num_towers
+ })
+ return minimize(
+ loss, accuracy, layer_collection, num_towers,
+ session_config=session_config)
+
+
+def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False):
+ """Train an MLP on MNIST using tf.estimator.
+
+ Args:
+ data_dir: string. Directory to read MNIST examples from.
+ num_epochs: int. Number of passes to make over the training set.
+ use_fake_data: bool. If True, generate a synthetic dataset.
+
+ Returns:
+ accuracy of model on the final minibatch of training data.
+ """
+
+ # Load a dataset.
+ def input_fn():
+ tf.logging.info("Loading MNIST into memory.")
+ return mnist.load_mnist(
+ data_dir,
+ num_epochs=num_epochs,
+ batch_size=64,
+ flatten_images=True,
+ use_fake_data=use_fake_data)
+
+ def model_fn(features, labels, mode, params):
+ """Model function for MLP trained with K-FAC.
+
+ Args:
+ features: Tensor of shape [batch_size, input_size]. Input features.
+ labels: Tensor of shape [batch_size]. Target labels for training.
+ mode: tf.estimator.ModeKey. Must be TRAIN.
+ params: ignored.
+
+ Returns:
+ EstimatorSpec for training.
+
+ Raises:
+ ValueError: If 'mode' is anything other than TRAIN.
+ """
+ del params
+
+ if mode != tf.estimator.ModeKeys.TRAIN:
+ raise ValueError("Only training is supposed with this API.")
+
+ # Build an MLP.
+ layer_collection = lc.LayerCollection()
+ loss, accuracy = build_model(
+ features, labels, num_labels=10, layer_collection=layer_collection)
+
+ # Train with K-FAC.
+ global_step = tf.train.get_or_create_global_step()
+ optimizer = opt.KfacOptimizer(
+ learning_rate=tf.train.exponential_decay(
+ 0.00002, global_step, 10000, 0.5, staircase=True),
+ cov_ema_decay=0.95,
+ damping=0.0001,
+ layer_collection=layer_collection,
+ momentum=0.99)
+
+ (cov_update_thunks,
+ inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
+
+ def make_update_op(update_thunks):
+ update_ops = [thunk() for thunk in update_thunks]
+ return tf.group(*update_ops)
+
+ def make_batch_executed_op(update_thunks, batch_size=1):
+ return tf.group(*tf.contrib.kfac.utils.batch_execute(
+ global_step, update_thunks, batch_size=batch_size))
+
+ # Run cov_update_op every step. Run one inv_update_op per step.
+ cov_update_op = make_update_op(cov_update_thunks)
+ with tf.control_dependencies([cov_update_op]):
+ # But make sure to execute all the inverse ops on the first step
+ inverse_op = tf.cond(tf.equal(global_step, 0),
+ lambda: make_update_op(inv_update_thunks),
+ lambda: make_batch_executed_op(inv_update_thunks))
+ with tf.control_dependencies([inverse_op]):
+ train_op = optimizer.minimize(loss, global_step=global_step)
+
+ # Print metrics every 5 sec.
+ hooks = [
+ tf.train.LoggingTensorHook(
+ {
+ "loss": loss,
+ "accuracy": accuracy
+ }, every_n_secs=5),
+ ]
+ return tf.estimator.EstimatorSpec(
+ mode=mode, loss=loss, train_op=train_op, training_hooks=hooks)
+
+ run_config = tf.estimator.RunConfig(
+ model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100)
+
+ # Train with the Estimator until input_fn() is exhausted. This is a
+ # prerequisite for TPU compatibility.
+ estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
+ estimator.train(input_fn=input_fn)
diff --git a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py
new file mode 100644
index 0000000000..9c34ade1d2
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py
@@ -0,0 +1,64 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Train an MLP on MNIST using K-FAC.
+
+See mlp.py for details.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+import tensorflow as tf
+
+from tensorflow.contrib.kfac.examples import mlp
+
+FLAGS = None
+
+
+def main(argv):
+ _ = argv
+ if FLAGS.use_estimator:
+ if FLAGS.num_towers != 1:
+ raise ValueError("Only 1 device supported in tf.estimator example.")
+ mlp.train_mnist_estimator(FLAGS.data_dir, num_epochs=200)
+ elif FLAGS.num_towers > 1:
+ mlp.train_mnist_multitower(
+ FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers)
+ else:
+ mlp.train_mnist(FLAGS.data_dir, num_epochs=200)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--data_dir",
+ type=str,
+ default="/tmp/mnist",
+ help="Directory to store dataset in.")
+ parser.add_argument(
+ "--num_towers",
+ type=int,
+ default=1,
+ help="Number of CPUs to split minibatch across.")
+ parser.add_argument(
+ "--use_estimator",
+ action="store_true",
+ help="Use tf.estimator API to train.")
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
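The flags above select between the three entry points in mlp.py. As a rough illustration, the default single-tower, non-estimator path is equivalent to:

  from tensorflow.contrib.kfac.examples import mlp
  mlp.train_mnist("/tmp/mnist", num_epochs=200)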
diff --git a/tensorflow/contrib/kfac/examples/mnist.py b/tensorflow/contrib/kfac/examples/mnist.py
new file mode 100644
index 0000000000..547c4ab25d
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/mnist.py
@@ -0,0 +1,69 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for loading MNIST into TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+__all__ = [
+ 'load_mnist',
+]
+
+
+def load_mnist(data_dir,
+ num_epochs,
+ batch_size,
+ flatten_images=True,
+ use_fake_data=False):
+ """Loads MNIST dataset into memory.
+
+ Args:
+ data_dir: string. Directory to read MNIST examples from.
+ num_epochs: int. Number of passes to make over the dataset.
+ batch_size: int. Number of examples per minibatch.
+ flatten_images: bool. If True, [28, 28, 1]-shaped images are flattened into
+ [784]-shaped vectors.
+ use_fake_data: bool. If True, generate a synthetic dataset rather than
+ reading MNIST in.
+
+ Returns:
+ examples: Tensor of shape [batch_size, 784] if 'flatten_images' is
+ True, else [batch_size, 28, 28, 1]. Each row is one example.
+ Values in [0, 1].
+ labels: Tensor of shape [batch_size]. Integer class index corresponding to
+ each example. Values in {0...9}.
+ """
+ if use_fake_data:
+ rng = np.random.RandomState(42)
+ num_examples = batch_size * 4
+ images = rng.rand(num_examples, 28 * 28)
+ if not flatten_images:
+ images = np.reshape(images, [num_examples, 28, 28, 1])
+ labels = rng.randint(10, size=num_examples)
+ else:
+ mnist_data = tf.contrib.learn.datasets.mnist.read_data_sets(
+ data_dir, reshape=flatten_images)
+ num_examples = len(mnist_data.train.labels)
+ images = mnist_data.train.images
+ labels = mnist_data.train.labels
+
+ dataset = tf.data.Dataset.from_tensor_slices((np.asarray(
+ images, dtype=np.float32), np.asarray(labels, dtype=np.int64)))
+ return (dataset.repeat(num_epochs).shuffle(num_examples).batch(batch_size)
+ .make_one_shot_iterator().get_next())
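A minimal usage sketch for load_mnist, mirroring how the example tests below call it with synthetic data (so no real data_dir is needed):

  import tensorflow as tf
  from tensorflow.contrib.kfac.examples import mnist

  examples, labels = mnist.load_mnist(
      data_dir=None, num_epochs=1, batch_size=64, use_fake_data=True)
  with tf.Session() as sess:
    examples_, labels_ = sess.run([examples, labels])
    # examples_.shape == (64, 784); labels_.shape == (64,)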
diff --git a/tensorflow/contrib/kfac/examples/tests/BUILD b/tensorflow/contrib/kfac/examples/tests/BUILD
new file mode 100644
index 0000000000..ede7f183fe
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/tests/BUILD
@@ -0,0 +1,52 @@
+package(default_visibility = ["//visibility:private"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_test(
+ name = "mlp_test",
+ size = "large",
+ srcs = ["mlp_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "notsan",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/kfac/examples:mlp",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "convnet_test",
+ size = "large",
+ srcs = ["convnet_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "notsan",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/kfac",
+ "//tensorflow/contrib/kfac/examples:convnet",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "mnist_test",
+ srcs = ["mnist_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/kfac/examples:mnist",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py
new file mode 100644
index 0000000000..adecda7166
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/tests/convnet_test.py
@@ -0,0 +1,166 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for convnet.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.kfac import layer_collection as lc
+from tensorflow.contrib.kfac.examples import convnet
+
+
+class ConvNetTest(tf.test.TestCase):
+
+ def testConvLayer(self):
+ with tf.Graph().as_default():
+ pre, act, (w, b) = convnet.conv_layer(
+ layer_id=1,
+ inputs=tf.zeros([5, 3, 3, 2]),
+ kernel_size=3,
+ out_channels=5)
+ self.assertShapeEqual(np.zeros([5, 3, 3, 5]), pre)
+ self.assertShapeEqual(np.zeros([5, 3, 3, 5]), act)
+ self.assertShapeEqual(np.zeros([3, 3, 2, 5]), tf.convert_to_tensor(w))
+ self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b))
+ self.assertIsInstance(w, tf.Variable)
+ self.assertIsInstance(b, tf.Variable)
+ self.assertIn("conv_1", w.op.name)
+ self.assertIn("conv_1", b.op.name)
+
+ def testMaxPoolLayer(self):
+ with tf.Graph().as_default():
+ act = convnet.max_pool_layer(
+ layer_id=1, inputs=tf.zeros([5, 6, 6, 2]), kernel_size=5, stride=3)
+ self.assertShapeEqual(np.zeros([5, 2, 2, 2]), act)
+ self.assertEqual(act.op.name, "pool_1/pool")
+
+ def testLinearLayer(self):
+ with tf.Graph().as_default():
+ act, (w, b) = convnet.linear_layer(
+ layer_id=1, inputs=tf.zeros([5, 20]), output_size=5)
+ self.assertShapeEqual(np.zeros([5, 5]), act)
+ self.assertShapeEqual(np.zeros([20, 5]), tf.convert_to_tensor(w))
+ self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b))
+ self.assertIsInstance(w, tf.Variable)
+ self.assertIsInstance(b, tf.Variable)
+ self.assertIn("fc_1", w.op.name)
+ self.assertIn("fc_1", b.op.name)
+
+ def testBuildModel(self):
+ with tf.Graph().as_default():
+ x = tf.placeholder(tf.float32, [None, 6, 6, 3])
+ y = tf.placeholder(tf.int64, [None])
+ layer_collection = lc.LayerCollection()
+ loss, accuracy = convnet.build_model(
+ x, y, num_labels=5, layer_collection=layer_collection)
+
+ # Ensure layers and logits were registered.
+ self.assertEqual(len(layer_collection.fisher_blocks), 3)
+ self.assertEqual(len(layer_collection.losses), 1)
+
+ # Ensure inference doesn't crash.
+ with self.test_session() as sess:
+ sess.run(tf.global_variables_initializer())
+ feed_dict = {
+ x: np.random.randn(10, 6, 6, 3).astype(np.float32),
+ y: np.random.randint(5, size=10).astype(np.int64),
+ }
+ sess.run([loss, accuracy], feed_dict=feed_dict)
+
+ def _build_toy_problem(self):
+ """Construct a toy linear regression problem.
+
+ Initial loss should be,
+ 2.5 = 0.5 * (1^2 + 2^2)
+
+ Returns:
+ loss: 0-D Tensor representing loss to be minimized.
+ accuracy: 0-D Tensor representing model accuracy.
+ layer_collection: LayerCollection instance describing model architecture.
+ """
+ x = np.asarray([[1.], [2.]]).astype(np.float32)
+ y = np.asarray([1., 2.]).astype(np.float32)
+ x, y = (tf.data.Dataset.from_tensor_slices((x, y))
+ .repeat(100).batch(2).make_one_shot_iterator().get_next())
+ w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer())
+ y_hat = tf.matmul(x, w)
+ loss = tf.reduce_mean(0.5 * tf.square(y_hat - y))
+ accuracy = loss
+
+ layer_collection = lc.LayerCollection()
+ layer_collection.register_fully_connected(params=w, inputs=x, outputs=y_hat)
+ layer_collection.register_normal_predictive_distribution(y_hat)
+
+ return loss, accuracy, layer_collection
+
+ def testMinimizeLossSingleMachine(self):
+ with tf.Graph().as_default():
+ loss, accuracy, layer_collection = self._build_toy_problem()
+ accuracy_ = convnet.minimize_loss_single_machine(
+ loss, accuracy, layer_collection, device="/cpu:0")
+ self.assertLess(accuracy_, 2.0)
+
+ def testMinimizeLossDistributed(self):
+ with tf.Graph().as_default():
+ loss, accuracy, layer_collection = self._build_toy_problem()
+ accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker(
+ task_id=0,
+ is_chief=True,
+ num_worker_tasks=1,
+ num_ps_tasks=0,
+ master="",
+ checkpoint_dir=None,
+ loss=loss,
+ accuracy=accuracy,
+ layer_collection=layer_collection)
+ self.assertLess(accuracy_, 2.0)
+
+ def testTrainMnistSingleMachine(self):
+ with tf.Graph().as_default():
+ # Ensure model training doesn't crash.
+ #
+ # Ideally, we should check that accuracy increases as the model converges,
+ # but there are too few parameters for the model to effectively memorize
+ # the training set the way an MLP can.
+ convnet.train_mnist_single_machine(
+ data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0")
+
+ def testTrainMnistMultitower(self):
+ with tf.Graph().as_default():
+ # Ensure model training doesn't crash.
+ convnet.train_mnist_multitower(
+ data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
+
+ def testTrainMnistDistributed(self):
+ with tf.Graph().as_default():
+ # Ensure model training doesn't crash.
+ convnet.train_mnist_distributed_sync_replicas(
+ task_id=0,
+ is_chief=True,
+ num_worker_tasks=1,
+ num_ps_tasks=0,
+ master="",
+ data_dir=None,
+ num_epochs=2,
+ op_strategy="chief_worker",
+ use_fake_data=True)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/contrib/kfac/examples/tests/mlp_test.py b/tensorflow/contrib/kfac/examples/tests/mlp_test.py
new file mode 100644
index 0000000000..22da6c29f1
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/tests/mlp_test.py
@@ -0,0 +1,63 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for mlp.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.kfac.examples import mlp
+
+
+class MlpTest(tf.test.TestCase):
+
+ def testFcLayer(self):
+ with tf.Graph().as_default():
+ pre, act, (w, b) = mlp.fc_layer(
+ layer_id=1, inputs=tf.zeros([5, 3]), output_size=10)
+ self.assertShapeEqual(np.zeros([5, 10]), pre)
+ self.assertShapeEqual(np.zeros([5, 10]), act)
+ self.assertShapeEqual(np.zeros([3, 10]), tf.convert_to_tensor(w))
+ self.assertShapeEqual(np.zeros([10]), tf.convert_to_tensor(b))
+ self.assertIsInstance(w, tf.Variable)
+ self.assertIsInstance(b, tf.Variable)
+ self.assertIn("fc_1/", w.op.name)
+ self.assertIn("fc_1/", b.op.name)
+
+ def testTrainMnist(self):
+ with tf.Graph().as_default():
+ # Ensure model training doesn't crash.
+ #
+ # Ideally, we should check that accuracy increases as the model converges,
+ # but that takes a non-trivial amount of compute.
+ mlp.train_mnist(data_dir=None, num_epochs=1, use_fake_data=True)
+
+ def testTrainMnistMultitower(self):
+ with tf.Graph().as_default():
+ # Ensure model training doesn't crash.
+ mlp.train_mnist_multitower(
+ data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
+
+ def testTrainMnistEstimator(self):
+ with tf.Graph().as_default():
+ # Ensure model training doesn't crash.
+ mlp.train_mnist_estimator(data_dir=None, num_epochs=1, use_fake_data=True)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/contrib/kfac/examples/tests/mnist_test.py b/tensorflow/contrib/kfac/examples/tests/mnist_test.py
new file mode 100644
index 0000000000..92f8462357
--- /dev/null
+++ b/tensorflow/contrib/kfac/examples/tests/mnist_test.py
@@ -0,0 +1,72 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for mnist.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.kfac.examples import mnist
+
+
+class MnistTest(tf.test.TestCase):
+
+ def testValues(self):
+ """Ensure values are in their expected range."""
+ with tf.Graph().as_default():
+ examples, labels = mnist.load_mnist(
+ data_dir=None, num_epochs=1, batch_size=64, use_fake_data=True)
+
+ with self.test_session() as sess:
+ examples_, labels_ = sess.run([examples, labels])
+ self.assertTrue(np.all((0 <= examples_) & (examples_ < 1)))
+ self.assertTrue(np.all((0 <= labels_) & (labels_ < 10)))
+
+ def testFlattenedShapes(self):
+ """Ensure images are flattened into their appropriate shape."""
+ with tf.Graph().as_default():
+ examples, labels = mnist.load_mnist(
+ data_dir=None,
+ num_epochs=1,
+ batch_size=64,
+ flatten_images=True,
+ use_fake_data=True)
+
+ with self.test_session() as sess:
+ examples_, labels_ = sess.run([examples, labels])
+ self.assertEqual(examples_.shape, (64, 784))
+ self.assertEqual(labels_.shape, (64,))
+
+ def testNotFlattenedShapes(self):
+ """Ensure non-flattened images are their appropriate shape."""
+ with tf.Graph().as_default():
+ examples, labels = mnist.load_mnist(
+ data_dir=None,
+ num_epochs=1,
+ batch_size=64,
+ flatten_images=False,
+ use_fake_data=True)
+
+ with self.test_session() as sess:
+ examples_, labels_ = sess.run([examples, labels])
+ self.assertEqual(examples_.shape, (64, 28, 28, 1))
+ self.assertEqual(labels_.shape, (64,))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/kfac/g3doc/autoencoder.png b/tensorflow/contrib/kfac/g3doc/autoencoder.png
new file mode 100644
index 0000000000..20f93c7703
--- /dev/null
+++ b/tensorflow/contrib/kfac/g3doc/autoencoder.png
Binary files differ
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
new file mode 100644
index 0000000000..6e4a8d71ba
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
@@ -0,0 +1,160 @@
+package(default_visibility = ["//visibility:private"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_test(
+ name = "estimator_test",
+ srcs = ["estimator_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:fisher_estimator",
+ "//tensorflow/contrib/kfac/python/ops:layer_collection",
+ "//tensorflow/contrib/kfac/python/ops:utils",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:linalg_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "fisher_factors_test",
+ srcs = ["fisher_factors_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
+ "//tensorflow/contrib/kfac/python/ops:fisher_factors",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "fisher_blocks_test",
+ srcs = ["fisher_blocks_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
+ "//tensorflow/contrib/kfac/python/ops:layer_collection",
+ "//tensorflow/contrib/kfac/python/ops:linear_operator",
+ "//tensorflow/contrib/kfac/python/ops:utils",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "layer_collection_test",
+ srcs = ["layer_collection_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
+ "//tensorflow/contrib/kfac/python/ops:fisher_factors",
+ "//tensorflow/contrib/kfac/python/ops:layer_collection",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:linalg_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:variable_scope",
+ ],
+)
+
+py_test(
+ name = "optimizer_test",
+ srcs = ["optimizer_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:fisher_factors",
+ "//tensorflow/contrib/kfac/python/ops:kfac_optimizer",
+ "//tensorflow/contrib/kfac/python/ops:layer_collection",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "utils_test",
+ srcs = ["utils_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"], # TODO: needs investigation on Windows
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:utils",
+ "//tensorflow/contrib/tpu",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:linalg_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "op_queue_test",
+ srcs = ["op_queue_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:op_queue",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ ],
+)
+
+py_test(
+ name = "loss_functions_test",
+ srcs = ["loss_functions_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/kfac/python/ops:loss_functions",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:random_ops",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
new file mode 100644
index 0000000000..0e65d419a3
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
@@ -0,0 +1,310 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.estimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.kfac.python.ops import estimator
+from tensorflow.contrib.kfac.python.ops import layer_collection as lc
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import training_util
+
+_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"]
+
+
+class EstimatorTest(test.TestCase):
+
+ def setUp(self):
+ self._graph = ops.Graph()
+ with self._graph.as_default():
+ self.layer_collection = lc.LayerCollection()
+
+ self.inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32)
+ self.weights = variable_scope.get_variable(
+ "w", shape=(2, 2), dtype=dtypes.float32)
+ self.bias = variable_scope.get_variable(
+ "b", initializer=init_ops.zeros_initializer(), shape=(2, 1))
+ self.output = math_ops.matmul(self.inputs, self.weights) + self.bias
+
+ # Only register the weights.
+ self.layer_collection.register_fully_connected(
+ params=(self.weights,), inputs=self.inputs, outputs=self.output)
+
+ self.outputs = math_ops.tanh(self.output)
+ self.targets = array_ops.zeros_like(self.outputs)
+ self.layer_collection.register_categorical_predictive_distribution(
+ logits=self.outputs, targets=self.targets)
+
+ def testEstimatorInitManualRegistration(self):
+ with self._graph.as_default():
+ # We should be able to build an estimator for only the registered vars.
+ estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection
+ )
+
+ # Check that we throw an error if we try to build an estimator for vars
+ # that were not manually registered.
+ with self.assertRaises(ValueError):
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights, self.bias],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection
+ )
+ est.make_vars_and_create_op_thunks()
+
+ # Check that we throw an error if we don't include registered variables,
+ # i.e. self.weights
+ with self.assertRaises(ValueError):
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection)
+ est.make_vars_and_create_op_thunks()
+
+ @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
+ def testVariableWrongNumberOfUses(self, mock_uses):
+ with self.assertRaises(ValueError):
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection)
+ est.make_vars_and_create_op_thunks()
+
+ def testInvalidEstimationMode(self):
+ with self.assertRaises(ValueError):
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection,
+ estimation_mode="not_a_real_mode")
+ est.make_vars_and_create_op_thunks()
+
+ def testGradientsModeBuild(self):
+ with self._graph.as_default():
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection,
+ estimation_mode="gradients")
+ est.make_vars_and_create_op_thunks()
+
+ def testEmpiricalModeBuild(self):
+ with self._graph.as_default():
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection,
+ estimation_mode="empirical")
+ est.make_vars_and_create_op_thunks()
+
+ def testCurvaturePropModeBuild(self):
+ with self._graph.as_default():
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection,
+ estimation_mode="curvature_prop")
+ est.make_vars_and_create_op_thunks()
+
+ def testExactModeBuild(self):
+ with self._graph.as_default():
+ est = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ cov_ema_decay=0.1,
+ damping=0.2,
+ layer_collection=self.layer_collection,
+ estimation_mode="exact")
+ est.make_vars_and_create_op_thunks()
+
+ def test_cov_update_thunks(self):
+ """Ensures covariance update ops run once per global_step."""
+ with self._graph.as_default(), self.test_session() as sess:
+ fisher_estimator = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ layer_collection=self.layer_collection,
+ damping=0.2,
+ cov_ema_decay=0.0)
+
+ # Construct an op that executes one covariance update per step.
+ global_step = training_util.get_or_create_global_step()
+ (cov_variable_thunks, cov_update_op_thunks, _,
+ _) = fisher_estimator.create_ops_and_vars_thunks()
+ for thunk in cov_variable_thunks:
+ thunk()
+ cov_matrices = [
+ fisher_factor.get_cov()
+ for fisher_factor in self.layer_collection.get_factors()
+ ]
+ cov_update_op = control_flow_ops.case(
+ [(math_ops.equal(global_step, i), thunk)
+ for i, thunk in enumerate(cov_update_op_thunks)])
+ increment_global_step = global_step.assign_add(1)
+
+ sess.run(variables.global_variables_initializer())
+ initial_cov_values = sess.run(cov_matrices)
+
+ # Ensure there's one update per covariance matrix.
+ self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))
+
+ # This test is a no-op if there is only 1 covariance matrix.
+ assert len(cov_matrices) > 1
+
+ for i in range(len(cov_matrices)):
+ # Compare new and old covariance values
+ new_cov_values = sess.run(cov_matrices)
+ is_cov_equal = [
+ np.allclose(initial_cov_value, new_cov_value)
+ for (initial_cov_value,
+ new_cov_value) in zip(initial_cov_values, new_cov_values)
+ ]
+ num_cov_equal = sum(is_cov_equal)
+
+ # Ensure exactly one covariance matrix changes per step.
+ self.assertEqual(num_cov_equal, len(cov_matrices) - i)
+
+ # Run all covariance update ops.
+ sess.run(cov_update_op)
+ sess.run(increment_global_step)
+
+ def test_round_robin_placement(self):
+ """Check if the ops and variables are placed on devices correctly."""
+ with self._graph.as_default():
+ fisher_estimator = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ layer_collection=self.layer_collection,
+ damping=0.2,
+ cov_ema_decay=0.0,
+ cov_devices=["/cpu:{}".format(i) for i in range(2)],
+ inv_devices=["/cpu:{}".format(i) for i in range(2)])
+
+ # Construct an op that executes one covariance update per step.
+ (cov_update_thunks,
+ inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks(
+ scope="test")
+ cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
+ inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
+ self.assertEqual(cov_update_ops[0].device, "/device:CPU:0")
+ self.assertEqual(cov_update_ops[1].device, "/device:CPU:1")
+ self.assertEqual(inv_update_ops[0].device, "/device:CPU:0")
+ self.assertEqual(inv_update_ops[1].device, "/device:CPU:1")
+ cov_matrices = [
+ fisher_factor.get_cov()
+ for fisher_factor in self.layer_collection.get_factors()
+ ]
+ inv_matrices = [
+ matrix
+ for fisher_factor in self.layer_collection.get_factors()
+ for matrix in fisher_factor._matpower_by_exp_and_damping.values()
+ ]
+ self.assertEqual(cov_matrices[0].device, "/device:CPU:0")
+ self.assertEqual(cov_matrices[1].device, "/device:CPU:1")
+ # Inverse matrices need to be explicitly placed.
+ self.assertEqual(inv_matrices[0].device, "")
+ self.assertEqual(inv_matrices[1].device, "")
+
+ def test_inv_update_thunks(self):
+ """Ensures inverse update ops run once per global_step."""
+ with self._graph.as_default(), self.test_session() as sess:
+ fisher_estimator = estimator.FisherEstimatorRoundRobin(
+ variables=[self.weights],
+ layer_collection=self.layer_collection,
+ damping=0.2,
+ cov_ema_decay=0.0)
+
+ # Construct op that updates one inverse per global step.
+ global_step = training_util.get_or_create_global_step()
+ (cov_variable_thunks, _, inv_variable_thunks,
+ inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks()
+ for thunk in cov_variable_thunks:
+ thunk()
+ for thunk in inv_variable_thunks:
+ thunk()
+ inv_matrices = [
+ matrix
+ for fisher_factor in self.layer_collection.get_factors()
+ for matrix in fisher_factor._matpower_by_exp_and_damping.values()
+ ]
+ inv_update_op = control_flow_ops.case(
+ [(math_ops.equal(global_step, i), thunk)
+ for i, thunk in enumerate(inv_update_op_thunks)])
+ increment_global_step = global_step.assign_add(1)
+
+ sess.run(variables.global_variables_initializer())
+ initial_inv_values = sess.run(inv_matrices)
+
+ # Ensure there's one update per inverse matrix. This is true as long as
+ # there's no fan-in/fan-out or parameter re-use.
+ self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))
+
+ # This test is a no-op if there is only 1 inverse matrix.
+ assert len(inv_matrices) > 1
+
+ # Assign each covariance matrix a value other than the identity. This
+ # ensures that the inverse matrices are updated to something different as
+ # well.
+ cov_matrices = [
+ fisher_factor.get_cov()
+ for fisher_factor in self.layer_collection.get_factors()
+ ]
+ sess.run([
+ cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
+ for cov_matrix in cov_matrices
+ ])
+
+ for i in range(len(inv_matrices)):
+ # Compare new and old inverse values
+ new_inv_values = sess.run(inv_matrices)
+ is_inv_equal = [
+ np.allclose(initial_inv_value, new_inv_value)
+ for (initial_inv_value,
+ new_inv_value) in zip(initial_inv_values, new_inv_values)
+ ]
+ num_inv_equal = sum(is_inv_equal)
+
+ # Ensure exactly one inverse matrix changes per step.
+ self.assertEqual(num_inv_equal, len(inv_matrices) - i)
+
+ # Run all inverse update ops.
+ sess.run(inv_update_op)
+ sess.run(increment_global_step)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
new file mode 100644
index 0000000000..86ec7a095a
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
@@ -0,0 +1,1018 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.fisher_blocks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
+from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
+from tensorflow.contrib.kfac.python.ops import layer_collection as lc
+from tensorflow.contrib.kfac.python.ops import linear_operator as lo
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.platform import test
+
+
+# We need to set these constants since the numerical values used in the tests
+# were chosen when these used to be the defaults.
+ff.set_global_constants(init_covariances_at_zero=False,
+ zero_debias=False,
+ init_inverses_at_zero=False)
+
+# TODO(b/78538100): As far as I can tell, all the tests that say "Make sure our
+# inverse is something other than the identity" are actually broken. They never
+# run the covariance update ops and so the inverse actually is the identity
+# (possibly plus the damping term, which would still make it a multiple of the
+# identity).
+
+
+def _make_psd(dim):
+ """Constructs a PSD matrix of the given dimension."""
+ mat = np.ones((dim, dim), dtype=np.float32)
+ mat[np.arange(dim), np.arange(dim)] = 2. + np.arange(dim)
+ return array_ops.constant(mat)
+
+
+class UtilsTest(test.TestCase):
+
+ def testComputePiTracenorm(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ diag = ops.convert_to_tensor([1., 2., 0., 1.])
+ left_factor = lo.LinearOperatorDiag(diag)
+ right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2]))
+
+ # pi is the sqrt of the left trace norm divided by the right trace norm
+ pi = fb.compute_pi_tracenorm(left_factor, right_factor)
+
+ pi_val = sess.run(pi)
+ self.assertEqual(1., pi_val)
+
+
+class FullFBTest(test.TestCase):
+
+ def testFullFBInitSingleTensor(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ self.assertAllEqual(params, block.tensors_to_compute_grads())
+
+ def testFullFBInitTensorTuple(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ self.assertAllEqual(params, block.tensors_to_compute_grads())
+
+ def testInstantiateFactors(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ grads = (params[0]**2, math_ops.sqrt(params[1]))
+ block.instantiate_factors(grads, 0.5)
+
+ def testMultiplyInverseTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = (params[0]**2, math_ops.sqrt(params[1]))
+ block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_inverse_update_ops())
+
+ vector = array_ops.ones(3,) * 2
+ output = block.multiply_inverse(vector)
+
+ self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
+
+ def testMultiplyInverseNotTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = array_ops.constant([[1.], [2.]])
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = params**2
+ block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_inverse_update_ops())
+
+ vector = array_ops.ones(2,) * 2
+ output = block.multiply_inverse(vector)
+
+ self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
+
+ def testMultiplyInverseAgainstExplicit(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.FullFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
+ damping = 0.5
+ block.instantiate_factors((grads,), damping)
+ block._factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(state_ops.assign(block._factor._cov, _make_psd(3)))
+ sess.run(block._factor.make_inverse_update_ops())
+
+ v_flat = np.array([4., 5., 6.], dtype=np.float32)
+ vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
+ output = block.multiply_inverse(vector)
+ output_flat = sess.run(utils.tensors_to_column(output)).ravel()
+
+ full = sess.run(block.full_fisher_block())
+ explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
+
+ self.assertAllClose(output_flat, explicit)
+
+
+class NaiveDiagonalFBTest(test.TestCase):
+
+ def testNaiveDiagonalFBInitSingleTensor(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ self.assertAllEqual(params, block.tensors_to_compute_grads())
+
+ def testNaiveDiagonalFBInitTensorTuple(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ self.assertAllEqual(params, block.tensors_to_compute_grads())
+
+ def testInstantiateFactors(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+
+ grads = (params[0]**2, math_ops.sqrt(params[1]))
+ block.instantiate_factors(grads, 0.5)
+
+ def testMultiplyInverseTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = (params[0]**2, math_ops.sqrt(params[1]))
+ block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_inverse_update_ops())
+
+ vector = array_ops.ones(3,) * 2
+ output = block.multiply_inverse(vector)
+
+ self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
+
+ def testMultiplyInverseNotTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = array_ops.constant([[1.], [2.]])
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = params**2
+ block.instantiate_factors((grads,), 0.5)
+ block._factor.instantiate_cov_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_inverse_update_ops())
+ vector = array_ops.ones(2,) * 2
+ output = block.multiply_inverse(vector)
+
+ self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
+
+ def testMultiplyInverseAgainstExplicit(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
+ block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
+ block.register_additional_tower(32)
+ grads = (params[0]**2, math_ops.sqrt(params[1]))
+ damping = 0.5
+ block.instantiate_factors((grads,), damping)
+ block._factor.instantiate_cov_variables()
+
+ cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1])
+ sess.run(state_ops.assign(block._factor._cov, cov))
+ sess.run(block._factor.make_inverse_update_ops())
+
+ v_flat = np.array([4., 5., 6.], dtype=np.float32)
+ vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
+ output = block.multiply_inverse(vector)
+ output_flat = sess.run(utils.tensors_to_column(output)).ravel()
+
+ full = sess.run(block.full_fisher_block())
+ explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
+ self.assertAllClose(output_flat, explicit)
+
+
+class FullyConnectedDiagonalFBTest(test.TestCase):
+
+ def setUp(self):
+ super(FullyConnectedDiagonalFBTest, self).setUp()
+
+ self.batch_size = 4
+ self.input_size = 6
+ self.output_size = 3
+
+ self.inputs = np.random.randn(self.batch_size, self.input_size).astype(
+ np.float32)
+ self.outputs = np.zeros([self.batch_size, self.output_size]).astype(
+ np.float32)
+ self.output_grads = np.random.randn(self.batch_size,
+ self.output_size).astype(np.float32)
+ self.w = np.random.randn(self.input_size, self.output_size).astype(
+ np.float32)
+ self.b = np.random.randn(self.output_size).astype(np.float32)
+
+ def fisherApprox(self, has_bias=False):
+ """Fisher approximation using default inputs."""
+ if has_bias:
+ inputs = np.concatenate(
+ [self.inputs, np.ones([self.batch_size, 1])], axis=1)
+ else:
+ inputs = self.inputs
+ return self.buildDiagonalFisherApproximation(inputs, self.output_grads)
+
+ def buildDiagonalFisherApproximation(self, inputs, output_grads):
+ """Builds explicit diagonal Fisher approximation.
+
+ Fisher's diagonal is approximated by the elements of (d loss / d w) squared,
+ where, per example, d loss / d w = outer(input, output_grad)
+
+ and the squared per-example gradients are averaged over examples.
+
+ Args:
+ inputs: np.array of shape [batch_size, input_size].
+ output_grads: np.array of shape [batch_size, output_size].
+
+ Returns:
+ Diagonal np.array of shape [num_params, num_params] for num_params =
+ input_size * output_size.
+ """
+ batch_size = inputs.shape[0]
+ assert output_grads.shape[0] == batch_size
+ input_size = inputs.shape[1]
+ output_size = output_grads.shape[1]
+ fisher_diag = np.zeros((input_size, output_size))
+ for i in range(batch_size):
+ fisher_diag += np.square(np.outer(inputs[i], output_grads[i]))
+ return np.diag(fisher_diag.flatten()) / batch_size
+
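Written as an equation, the approximation that buildDiagonalFisherApproximation constructs above is, for a batch of N examples with input rows x_n and output-gradient rows g_n,

  \hat{F}_{\mathrm{diag}} = \frac{1}{N} \sum_{n=1}^{N} \left( x_n g_n^{\top} \right)^{\circ 2}

where (.)^{\circ 2} denotes the elementwise square; the test then places vec(.) of this matrix on the diagonal of a [num_params, num_params] matrix.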
+ def testMultiply(self):
+ result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
+ [self.output_grads])
+
+ # Construct Fisher-vector product.
+ expected_result = self.fisherApprox().dot(self.w.flatten())
+ expected_result = expected_result.reshape(
+ [self.input_size, self.output_size])
+
+ self.assertAllClose(expected_result, result)
+
+ def testMultiplyInverse(self):
+ _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
+ [self.output_grads])
+
+ # Construct inverse Fisher-vector product.
+ expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
+ expected_result = expected_result.reshape(
+ [self.input_size, self.output_size])
+
+ self.assertAllClose(expected_result, result)
+
+ def testRegisterAdditionalTower(self):
+ """Ensure 1 big tower and 2 small towers are equivalent."""
+ multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
+ self.w, [self.inputs], [self.outputs], [self.output_grads])
+ multiply_result_small, multiply_inverse_result_small = (
+ self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
+ np.split(self.outputs, 2),
+ np.split(self.output_grads, 2)))
+
+ self.assertAllClose(multiply_result_big, multiply_result_small)
+ self.assertAllClose(multiply_inverse_result_big,
+ multiply_inverse_result_small)
+
+ def testMultiplyHasBias(self):
+ result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
+ [self.outputs], [self.output_grads])
+ expected_result = self.fisherApprox(True).dot(
+ np.concatenate([self.w.flatten(), self.b.flatten()]))
+ expected_result = expected_result.reshape(
+ [self.input_size + 1, self.output_size])
+ expected_result = (expected_result[:-1], expected_result[-1])
+
+ self.assertEqual(len(result), 2)
+ self.assertAllClose(expected_result[0], result[0])
+ self.assertAllClose(expected_result[1], result[1])
+
+ def runFisherBlockOps(self, params, inputs, outputs, output_grads):
+ """Run Ops guaranteed by FisherBlock interface.
+
+ Args:
+ params: Tensor or 2-tuple of Tensors. Represents weights or weights and
+ bias of this layer.
+ inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
+ layer.
+ outputs: list of Tensors of shape [batch_size, output_size].
+ Preactivations produced by layer.
+ output_grads: list of Tensors of shape [batch_size, output_size].
+ Gradient of loss with respect to 'outputs'.
+
+ Returns:
+ multiply_result: Result of FisherBlock.multiply(params)
+ multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
+ """
+ with ops.Graph().as_default(), self.test_session() as sess:
+ inputs = as_tensors(inputs)
+ outputs = as_tensors(outputs)
+ output_grads = as_tensors(output_grads)
+ params = as_tensors(params)
+
+ block = fb.FullyConnectedDiagonalFB(
+ lc.LayerCollection(), has_bias=isinstance(params, (tuple, list)))
+ for (i, o) in zip(inputs, outputs):
+ block.register_additional_tower(i, o)
+
+ block.instantiate_factors((output_grads,), damping=0.0)
+ block._factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_covariance_update_op(0.0))
+ multiply_result = sess.run(block.multiply(params))
+ multiply_inverse_result = sess.run(block.multiply_inverse(params))
+
+ return multiply_result, multiply_inverse_result
+
+
+class EmbeddingKFACFBTest(test.TestCase):
+
+ def testInstantiateFactors(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+
+ # Create a Fisher Block.
+ vocab_size = 5
+ block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
+
+ # Add some examples.
+ inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
+ outputs = array_ops.constant([[0.], [1.], [2.]])
+ block.register_additional_tower(inputs, outputs)
+
+ # Instantiate factor's variables. Ensure it doesn't fail.
+ grads = outputs**2.
+ damping = array_ops.constant(0.)
+ block.instantiate_factors(((grads,),), damping)
+
+ def testMultiplyInverse(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+
+ # Create a Fisher Block.
+ vocab_size = 5
+ block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
+
+ # Add some examples.
+ inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
+ outputs = array_ops.constant([[0.], [1.], [2.]])
+ block.register_additional_tower(inputs, outputs)
+
+ # Instantiate factor's variables. Ensure it doesn't fail.
+ grads = outputs**2.
+ damping = array_ops.constant(0.)
+ block.instantiate_factors(((grads,),), damping)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Create a sparse update.
+ indices = array_ops.constant([1, 3, 4])
+ values = array_ops.constant([[1.], [1.], [1.]])
+ sparse_vector = ops.IndexedSlices(
+ values, indices, dense_shape=[vocab_size, 1])
+ dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1])
+
+ # Compare Fisher-vector product against explicit result.
+ result = block.multiply_inverse(sparse_vector)
+ expected_result = linalg_ops.matrix_solve(block.full_fisher_block(),
+ dense_vector)
+
+ sess.run(tf_variables.global_variables_initializer())
+ self.assertAlmostEqual(
+ sess.run(expected_result[1]), sess.run(result.values[0]))
+ self.assertAlmostEqual(
+ sess.run(expected_result[3]), sess.run(result.values[1]))
+ self.assertAlmostEqual(
+ sess.run(expected_result[4]), sess.run(result.values[2]))
+
+
+class FullyConnectedKFACBasicFBTest(test.TestCase):
+
+ def testFullyConnectedKFACBasicFBInit(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([1., 2.])
+ outputs = array_ops.constant([3., 4.])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection())
+ block.register_additional_tower(inputs, outputs)
+
+ self.assertAllEqual([outputs], block.tensors_to_compute_grads())
+
+ def testInstantiateFactorsHasBias(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2.], [3., 4.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True)
+ block.register_additional_tower(inputs, outputs)
+
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+
+ def testInstantiateFactorsNoBias(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2.], [3., 4.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+ block.register_additional_tower(inputs, outputs)
+
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+
+ def testMultiplyInverseTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ vector = (
+ np.arange(2, 6).reshape(2, 2).astype(np.float32), #
+ np.arange(1, 3).reshape(2, 1).astype(np.float32))
+ output = block.multiply_inverse((array_ops.constant(vector[0]),
+ array_ops.constant(vector[1])))
+
+ output = sess.run(output)
+ self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
+ output[0])
+ self.assertAllClose([0.343146, 0.686291], output[1])
+
+ def testMultiplyInverseNotTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2.], [3., 4.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ vector = np.arange(2, 6).reshape(2, 2).astype(np.float32)
+ output = block.multiply_inverse(array_ops.constant(vector))
+
+ self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
+ sess.run(output))
+
+ def testMultiplyInverseAgainstExplicit(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ input_dim, output_dim = 3, 2
+ inputs = array_ops.zeros([32, input_dim])
+ outputs = array_ops.zeros([32, output_dim])
+ params = array_ops.zeros([input_dim, output_dim])
+ block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ damping = 0. # This test is only valid without damping.
+ block.instantiate_factors(((grads,),), damping)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+
+ sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
+ sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
+
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ v_flat = np.arange(6, dtype=np.float32)
+ vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
+ output = block.multiply_inverse(vector)
+ output_flat = sess.run(utils.tensors_to_column(output)).ravel()
+
+ full = sess.run(block.full_fisher_block())
+ explicit = np.dot(np.linalg.inv(full + damping * np.eye(6)), v_flat)
+
+ self.assertAllClose(output_flat, explicit)
+
+
+class ConvDiagonalFBTest(test.TestCase):
+
+ def setUp(self):
+ super(ConvDiagonalFBTest, self).setUp()
+
+ self.batch_size = 2
+ self.height = 8
+ self.width = 4
+ self.input_channels = 6
+ self.output_channels = 3
+ self.kernel_size = 1
+
+ self.inputs = np.random.randn(self.batch_size, self.height, self.width,
+ self.input_channels).astype(np.float32)
+ self.outputs = np.zeros(
+ [self.batch_size, self.height, self.width,
+ self.output_channels]).astype(np.float32)
+ self.output_grads = np.random.randn(
+ self.batch_size, self.height, self.width, self.output_channels).astype(
+ np.float32)
+ self.w = np.random.randn(self.kernel_size, self.kernel_size,
+ self.input_channels, self.output_channels).astype(
+ np.float32)
+ self.b = np.random.randn(self.output_channels).astype(np.float32)
+
+ def fisherApprox(self, has_bias=False):
+ """Fisher approximation using default inputs."""
+ if has_bias:
+ inputs = np.concatenate(
+ [self.inputs,
+ np.ones([self.batch_size, self.height, self.width, 1])],
+ axis=-1)
+ else:
+ inputs = self.inputs
+ return self.buildDiagonalFisherApproximation(inputs, self.output_grads,
+ self.kernel_size)
+
+ def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size):
+ r"""Builds explicit diagonal Fisher approximation.
+
+ Fisher's diagonal is approximated by the elements of (d loss / d w) squared,
+ where, per example, d loss / d w = \sum_{loc} outer(input_{loc}, output_grad_{loc}),
+
+ the squared per-example gradients are averaged over examples, and the sum
+ runs over the (x, y) locations at which the convolution is applied.
+
+ Args:
+ inputs: np.array of shape [batch_size, height, width, input_channels].
+ output_grads: np.array of shape [batch_size, height, width,
+ output_channels].
+ kernel_size: int. height and width of kernel.
+
+ Returns:
+ Diagonal np.array of shape [num_params, num_params] for num_params =
+ kernel_size^2 * input_channels * output_channels.
+ """
+ batch_size, height, width, input_channels = inputs.shape
+ assert output_grads.shape[0] == batch_size
+ assert output_grads.shape[1] == height
+ assert output_grads.shape[2] == width
+ output_channels = output_grads.shape[3]
+
+ # If kernel_size == 1, then we don't need to worry about capturing context
+ # around the pixel upon which a convolution is applied. This makes testing
+ # easier.
+ assert kernel_size == 1, "kernel_size != 1 isn't supported."
+ num_locations = height * width
+ inputs = np.reshape(inputs, [batch_size, num_locations, input_channels])
+ output_grads = np.reshape(output_grads,
+ [batch_size, num_locations, output_channels])
+
+ fisher_diag = np.zeros((input_channels, output_channels))
+ for i in range(batch_size):
+ # Each example contributes the elementwise square of a sum of outer products.
+ example_fisher_diag = np.zeros((input_channels, output_channels))
+ for j in range(num_locations):
+ example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j])
+ fisher_diag += np.square(example_fisher_diag)
+
+ # Normalize by batch_size (not num_locations).
+ return np.diag(fisher_diag.flatten()) / batch_size
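+
+ # Illustrative note (not part of the helper above): because the
+ # approximation is diagonal, a Fisher-vector product reduces to an
+ # elementwise product in NumPy, e.g.
+ #   fvp = np.diag(self.fisherApprox()) * self.w.flatten()
+ # which agrees with the dense fisherApprox().dot(...) used in testMultiply
+ # below.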
+
+ def testMultiply(self):
+ result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
+ [self.output_grads])
+
+ # Construct Fisher-vector product.
+ expected_result = self.fisherApprox().dot(self.w.flatten())
+ expected_result = expected_result.reshape([
+ self.kernel_size, self.kernel_size, self.input_channels,
+ self.output_channels
+ ])
+
+ self.assertAllClose(expected_result, result)
+
+ def testMultiplyInverse(self):
+ _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
+ [self.output_grads])
+
+ # Construct inverse Fisher-vector product.
+ expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
+ expected_result = expected_result.reshape([
+ self.kernel_size, self.kernel_size, self.input_channels,
+ self.output_channels
+ ])
+
+ self.assertAllClose(expected_result, result, atol=1e-3)
+
+ def testRegisterAdditionalTower(self):
+ """Ensure 1 big tower and 2 small towers are equivalent."""
+ multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
+ self.w, [self.inputs], [self.outputs], [self.output_grads])
+ multiply_result_small, multiply_inverse_result_small = (
+ self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
+ np.split(self.outputs, 2),
+ np.split(self.output_grads, 2)))
+
+ self.assertAllClose(multiply_result_big, multiply_result_small)
+ self.assertAllClose(multiply_inverse_result_big,
+ multiply_inverse_result_small)
+
+ def testMultiplyHasBias(self):
+ result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
+ [self.outputs], [self.output_grads])
+ # Clone 'b' along 'input_channels' dimension.
+ b_filter = np.tile(
+ np.reshape(self.b, [1, 1, 1, self.output_channels]),
+ [self.kernel_size, self.kernel_size, 1, 1])
+ params = np.concatenate([self.w, b_filter], axis=2)
+ expected_result = self.fisherApprox(True).dot(params.flatten())
+
+ # Extract 'b' from concatenated parameters.
+ expected_result = expected_result.reshape([
+ self.kernel_size, self.kernel_size, self.input_channels + 1,
+ self.output_channels
+ ])
+ expected_result = (expected_result[:, :, 0:-1, :],
+ np.reshape(expected_result[:, :, -1, :],
+ [self.output_channels]))
+
+ self.assertEqual(len(result), 2)
+ self.assertAllClose(expected_result[0], result[0])
+ self.assertAllClose(expected_result[1], result[1])
+
+ def runFisherBlockOps(self, params, inputs, outputs, output_grads):
+ """Run Ops guaranteed by FisherBlock interface.
+
+ Args:
+ params: Tensor or 2-tuple of Tensors. Represents weights or weights and
+ bias of this layer.
+ inputs: list of Tensors of shape [batch_size, height, width,
+ input_channels]. Inputs to the layer.
+ outputs: list of Tensors of shape [batch_size, height, width,
+ output_channels]. Preactivations produced by the layer.
+ output_grads: list of Tensors of shape [batch_size, height, width,
+ output_channels]. Gradient of the loss with respect to 'outputs'.
+
+ Returns:
+ multiply_result: Result of FisherBlock.multiply(params)
+ multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
+ """
+ with ops.Graph().as_default(), self.test_session() as sess:
+ inputs = as_tensors(inputs)
+ outputs = as_tensors(outputs)
+ output_grads = as_tensors(output_grads)
+ params = as_tensors(params)
+
+ block = fb.ConvDiagonalFB(
+ lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME')
+ for (i, o) in zip(inputs, outputs):
+ block.register_additional_tower(i, o)
+
+ block.instantiate_factors((output_grads,), damping=0.0)
+ block._factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._factor.make_covariance_update_op(0.0))
+ multiply_result = sess.run(block.multiply(params))
+ multiply_inverse_result = sess.run(block.multiply_inverse(params))
+
+ return multiply_result, multiply_inverse_result
+
+
+class DepthwiseConvKFCBasicFBTest(test.TestCase):
+
+ def testInstantiateFactors(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((3, 3, 8, 2))
+ inputs = random_ops.random_normal((32, 5, 5, 8))
+ outputs = random_ops.random_normal((32, 5, 5, 16))
+ layer_collection = lc.LayerCollection()
+ block = fb.DepthwiseConvKFCBasicFB(
+ layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(([grads],), 0.5)
+
+ def testMultiplyInverse(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((3, 3, 8, 2))
+ inputs = random_ops.random_normal((32, 5, 5, 8))
+ outputs = random_ops.random_normal((32, 5, 5, 16))
+ layer_collection = lc.LayerCollection()
+ block = fb.DepthwiseConvKFCBasicFB(
+ layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(([grads],), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Ensure inverse update op doesn't crash.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run([
+ factor.make_inverse_update_ops()
+ for factor in layer_collection.get_factors()
+ ])
+
+ # Ensure inverse-vector multiply doesn't crash.
+ output = block.multiply_inverse(params)
+ sess.run(output)
+
+ # Ensure same shape.
+ self.assertAllEqual(output.shape, params.shape)
+
+
+class ConvKFCBasicFBTest(test.TestCase):
+
+ def _testConvKFCBasicFBInitParams(self, params):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ if isinstance(params, (list, tuple)):
+ params = [array_ops.constant(param) for param in params]
+ else:
+ params = array_ops.constant(params)
+ inputs = random_ops.random_normal((2, 2, 2))
+ outputs = random_ops.random_normal((2, 2, 2))
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+
+ self.assertAllEqual([outputs], block.tensors_to_compute_grads())
+
+ def testConvKFCBasicFBInitParamsParamsTuple(self):
+ self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])])
+
+ def testConvKFCBasicFBInitParamsParamsSingle(self):
+ self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])])
+
+ def testMultiplyInverseTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((2, 2, 2, 2))
+ inputs = random_ops.random_normal((2, 2, 2, 2))
+ outputs = random_ops.random_normal((2, 2, 2, 2))
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32),
+ np.arange(2, 4).reshape(2, 1).astype(np.float32))
+ output = block.multiply_inverse((array_ops.constant(vector[0]),
+ array_ops.constant(vector[1])))
+
+ output = sess.run(output)
+ self.assertAllClose([0.136455, 0.27291], output[0][0])
+ self.assertAllClose([0.27291, 0.409365], output[1])
+
+ def testMultiplyInverseNotTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = random_ops.random_normal((2, 2, 2, 2))
+ inputs = random_ops.random_normal((2, 2, 2, 2))
+ outputs = random_ops.random_normal((2, 2, 2, 2))
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ self.assertFalse(block._has_bias)
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ vector = np.arange(1, 17).reshape(8, 2).astype(np.float32)
+ output = block.multiply_inverse(array_ops.constant(vector))
+
+ self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])
+
+ def testMultiplyInverseNotTupleWithBias(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = [random_ops.random_normal((2, 2, 2, 2))]
+ inputs = random_ops.random_normal((2, 2, 2, 2))
+ outputs = random_ops.random_normal((2, 2, 2, 2))
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ self.assertTrue(block._has_bias)
+ grads = outputs**2
+ block.instantiate_factors(((grads,),), 0.5)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ # Make sure our inverse is something other than the identity.
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ vector = np.arange(1, 19).reshape(9, 2).astype(np.float32)
+ output = block.multiply_inverse(array_ops.constant(vector))
+
+ self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])
+
+ def testMultiplyInverseAgainstExplicit(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ params = array_ops.zeros((2, 2, 2, 2))
+ inputs = array_ops.zeros((2, 2, 2, 2))
+ outputs = array_ops.zeros((2, 2, 2, 2))
+ block = fb.ConvKFCBasicFB(
+ lc.LayerCollection(), params=params, padding='SAME')
+ block.register_additional_tower(inputs, outputs)
+ grads = outputs**2
+ damping = 0. # This test is only valid without damping.
+ block.instantiate_factors(((grads,),), damping)
+ block._input_factor.instantiate_cov_variables()
+ block._output_factor.instantiate_cov_variables()
+ block.register_inverse()
+ block._input_factor.instantiate_inv_variables()
+ block._output_factor.instantiate_inv_variables()
+
+ sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
+ sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
+ sess.run(block._input_factor.make_inverse_update_ops())
+ sess.run(block._output_factor.make_inverse_update_ops())
+
+ v_flat = np.arange(16, dtype=np.float32)
+ vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
+ output = block.multiply_inverse(vector)
+ output_flat = sess.run(utils.tensors_to_column(output)).ravel()
+
+ full = sess.run(block.full_fisher_block())
+ explicit = np.dot(np.linalg.inv(full + damping * np.eye(16)), v_flat)
+
+ self.assertAllClose(output_flat, explicit)
+
+
+class FullyConnectedSeriesFBTest(test.TestCase):
+
+ def testFullyConnectedSeriesFBInit(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([1., 2.])
+ outputs = array_ops.constant([3., 4.])
+ block = fb.FullyConnectedSeriesFB(lc.LayerCollection())
+ block.register_additional_tower([inputs], [outputs])
+ self.assertAllEqual([[outputs]], block.tensors_to_compute_grads())
+
+ def testInstantiateFactorsHasBias(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2.], [3., 4.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedSeriesFB(
+ lc.LayerCollection(),
+ has_bias=True)
+ block.register_additional_tower([inputs], [outputs])
+ grads = outputs**2
+ block.instantiate_factors((((grads,),),), 0.5)
+
+ def testInstantiateFactorsNoBias(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ inputs = array_ops.constant([[1., 2.], [3., 4.]])
+ outputs = array_ops.constant([[3., 4.], [5., 6.]])
+ block = fb.FullyConnectedSeriesFB(
+ lc.LayerCollection(),
+ has_bias=False)
+ block.register_additional_tower([inputs], [outputs])
+ grads = outputs**2
+ block.instantiate_factors((((grads,),),), 0.5)
+
+
+def as_tensors(tensor_or_tuple):
+ """Converts a potentially nested tuple of np.array to Tensors."""
+ if isinstance(tensor_or_tuple, (tuple, list)):
+ return tuple(as_tensors(t) for t in tensor_or_tuple)
+ return ops.convert_to_tensor(tensor_or_tuple)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
new file mode 100644
index 0000000000..fad47cd02f
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -0,0 +1,955 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.fisher_factors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import numpy.random as npr
+
+from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
+from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.platform import test
+
+
+# We need to set these constants since the numerical values used in the tests
+# were chosen when these used to be the defaults.
+ff.set_global_constants(init_covariances_at_zero=False,
+ zero_debias=False,
+ init_inverses_at_zero=False)
+
+
+def make_damping_func(damping):
+ return fb._package_func(lambda: damping, damping)
+
+
+class FisherFactorTestingDummy(ff.FisherFactor):
+ """Dummy class to test the non-abstract methods on ff.FisherFactor."""
+
+ @property
+ def _var_scope(self):
+ return 'dummy/a_b_c'
+
+ @property
+ def _cov_shape(self):
+ raise NotImplementedError
+
+ @property
+ def _num_sources(self):
+ return 1
+
+ @property
+ def _dtype(self):
+ return dtypes.float32
+
+ def _compute_new_cov(self):
+ raise NotImplementedError
+
+ def instantiate_covariance(self):
+ pass
+
+ def make_inverse_update_ops(self):
+ return []
+
+ def get_cov(self):
+ raise NotImplementedError
+
+ def instantiate_inv_variables(self):
+ raise NotImplementedError
+
+ def _num_towers(self):
+ raise NotImplementedError
+
+ def _get_data_device(self):
+ raise NotImplementedError
+
+ def register_matpower(self, exp, damping_func):
+ raise NotImplementedError
+
+ def register_cholesky(self, damping_func):
+ raise NotImplementedError
+
+ def register_cholesky_inverse(self, damping_func):
+ raise NotImplementedError
+
+ def get_matpower(self, exp, damping_func):
+ raise NotImplementedError
+
+ def get_cholesky(self, damping_func):
+ raise NotImplementedError
+
+ def get_cholesky_inverse(self, damping_func):
+ raise NotImplementedError
+
+ def get_cov_as_linear_operator(self):
+ raise NotImplementedError
+
+
+class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor):
+ """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor.
+ """
+
+ def __init__(self, shape):
+ self._shape = shape
+ super(DenseSquareMatrixFactorTestingDummy, self).__init__()
+
+ @property
+ def _var_scope(self):
+ return 'dummy/a_b_c'
+
+ @property
+ def _cov_shape(self):
+ return self._shape
+
+ @property
+ def _num_sources(self):
+ return 1
+
+ @property
+ def _dtype(self):
+ return dtypes.float32
+
+ def _compute_new_cov(self):
+ raise NotImplementedError
+
+ def instantiate_covariance(self):
+ pass
+
+ def _num_towers(self):
+ raise NotImplementedError
+
+ def _get_data_device(self):
+ raise NotImplementedError
+
+
+class NumericalUtilsTest(test.TestCase):
+
+ def testComputeCovAgainstNumpy(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ npr.seed(0)
+ random_seed.set_random_seed(200)
+
+ x = npr.randn(100, 3)
+ cov = ff.compute_cov(array_ops.constant(x))
+ np_cov = np.dot(x.T, x) / x.shape[0]
+
+ self.assertAllClose(sess.run(cov), np_cov)
+
+ def testComputeCovAgainstNumpyWithAlternativeNormalizer(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ npr.seed(0)
+ random_seed.set_random_seed(200)
+
+ normalizer = 10.
+ x = npr.randn(100, 3)
+ cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer)
+ np_cov = np.dot(x.T, x) / normalizer
+
+ self.assertAllClose(sess.run(cov), np_cov)
+
+ def testAppendHomog(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ npr.seed(0)
+
+ m, n = 3, 4
+ a = npr.randn(m, n)
+ a_homog = ff.append_homog(array_ops.constant(a))
+ np_result = np.hstack([a, np.ones((m, 1))])
+
+ self.assertAllClose(sess.run(a_homog), np_result)
+
+
+class NameStringUtilFunctionTest(test.TestCase):
+
+ def _make_tensor(self):
+ x = array_ops.placeholder(dtypes.float64, (3, 1))
+ w = array_ops.constant(npr.RandomState(0).randn(3, 3))
+ y = math_ops.matmul(w, x)
+ g = gradients_impl.gradients(y, x)[0]
+ return g
+
+ def testScopeStringFromParamsSingleTensor(self):
+ with tf_ops.Graph().as_default():
+ g = self._make_tensor()
+ scope_string = ff.scope_string_from_params(g)
+ self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string)
+
+ def testScopeStringFromParamsMultipleTensors(self):
+ with tf_ops.Graph().as_default():
+ x = array_ops.constant(1,)
+ y = array_ops.constant(2,)
+ scope_string = ff.scope_string_from_params((x, y))
+ self.assertEqual('Const_Const_1', scope_string)
+
+ def testScopeStringFromParamsMultipleTypes(self):
+ with tf_ops.Graph().as_default():
+ x = array_ops.constant(1,)
+ y = array_ops.constant(2,)
+ scope_string = ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4,
+ (x, y)])
+ self.assertEqual('1-2-3_foo_True_4_Const__Const_1', scope_string)
+
+ def testScopeStringFromParamsUnsupportedType(self):
+ with tf_ops.Graph().as_default():
+ x = array_ops.constant(1,)
+ y = array_ops.constant(2,)
+ unsupported = 1.2 # Floats are not supported.
+ with self.assertRaises(ValueError):
+ ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, (x, y),
+ unsupported])
+
+ def testScopeStringFromName(self):
+ with tf_ops.Graph().as_default():
+ g = self._make_tensor()
+ scope_string = ff.scope_string_from_name(g)
+ self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string)
+
+ def testScalarOrTensorToString(self):
+ with tf_ops.Graph().as_default():
+ self.assertEqual(ff.scalar_or_tensor_to_string(5.), repr(5.))
+
+ g = self._make_tensor()
+ scope_string = ff.scope_string_from_name(g)
+ self.assertEqual(ff.scalar_or_tensor_to_string(g), scope_string)
+
+
+class FisherFactorTest(test.TestCase):
+
+ def testMakeInverseUpdateOps(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ factor = FisherFactorTestingDummy()
+
+ self.assertEqual(0, len(factor.make_inverse_update_ops()))
+
+
+class DenseSquareMatrixFactorTest(test.TestCase):
+
+ def testRegisterDampedInverse(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ shape = [2, 2]
+ factor = DenseSquareMatrixFactorTestingDummy(shape)
+ factor_var_scope = 'dummy/a_b_c'
+
+ damping_funcs = [make_damping_func(0.1),
+ make_damping_func(0.1),
+ make_damping_func(1e-5),
+ make_damping_func(1e-5)]
+ for damping_func in damping_funcs:
+ factor.register_inverse(damping_func)
+
+ factor.instantiate_inv_variables()
+
+ inv = factor.get_inverse(damping_funcs[0]).to_dense()
+ self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense())
+ self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense())
+ self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(),
+ factor.get_inverse(damping_funcs[3]).to_dense())
+ factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
+ factor_var_scope)
+ factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
+
+ self.assertEqual(set([inv,
+ factor.get_inverse(damping_funcs[2]).to_dense()]),
+ set(factor_tensors))
+ self.assertEqual(shape, inv.get_shape())
+
+ def testRegisterMatpower(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ shape = [3, 3]
+ factor = DenseSquareMatrixFactorTestingDummy(shape)
+ factor_var_scope = 'dummy/a_b_c'
+
+ # TODO(b/74201126): Change to using the same func for both once
+ # Topohash is in place.
+ damping_func_1 = make_damping_func(0.5)
+ damping_func_2 = make_damping_func(0.5)
+
+ factor.register_matpower(-0.5, damping_func_1)
+ factor.register_matpower(2, damping_func_2)
+
+ factor.instantiate_inv_variables()
+
+ factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
+ factor_var_scope)
+
+ factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
+
+ matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense()
+ matpower2 = factor.get_matpower(2, damping_func_2).to_dense()
+
+ self.assertEqual(set([matpower1, matpower2]), set(factor_tensors))
+
+ self.assertEqual(shape, matpower1.get_shape())
+ self.assertEqual(shape, matpower2.get_shape())
+
+ def testMakeInverseUpdateOps(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ factor = FisherFactorTestingDummy()
+
+ self.assertEqual(0, len(factor.make_inverse_update_ops()))
+
+ def testMakeInverseUpdateOpsManyInversesEigenDecomp(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ cov = np.array([[1., 2.], [3., 4.]])
+ factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
+ factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
+
+ damping_funcs = []
+ for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1):
+ damping_funcs.append(make_damping_func(1./i))
+
+ for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
+ factor.register_inverse(damping_funcs[i])
+
+ factor.instantiate_inv_variables()
+ ops = factor.make_inverse_update_ops()
+ self.assertEqual(1, len(ops))
+
+ sess.run(tf_variables.global_variables_initializer())
+ new_invs = []
+ sess.run(ops)
+ for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
+ # The inverse op will assign the damped inverse of cov to the inv var.
+ new_invs.append(
+ sess.run(factor.get_inverse(damping_funcs[i]).to_dense()))
+
+ # We want to see that the new invs are all different from each other.
+ for i in range(len(new_invs)):
+ for j in range(i + 1, len(new_invs)):
+ # Just check the first element.
+ self.assertNotEqual(new_invs[i][0][0], new_invs[j][0][0])
+
+ def testMakeInverseUpdateOpsMatPowerEigenDecomp(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ cov = np.array([[6., 2.], [2., 4.]])
+ factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
+ factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
+ exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power
+ damping = 0.5
+ damping_func = make_damping_func(damping)
+
+ factor.register_matpower(exp, damping_func)
+ factor.instantiate_inv_variables()
+ ops = factor.make_inverse_update_ops()
+ self.assertEqual(1, len(ops))
+
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(ops[0])
+ matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense())
+ matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
+ self.assertAllClose(matpower, matpower_np)
+
+ def testMakeInverseUpdateOpsNoEigenDecomp(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric
+ factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
+ factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
+
+ damping_func = make_damping_func(0)
+
+ factor.register_inverse(damping_func)
+ factor.instantiate_inv_variables()
+ ops = factor.make_inverse_update_ops()
+ self.assertEqual(1, len(ops))
+
+ sess.run(tf_variables.global_variables_initializer())
+ # The inverse op will assign the damped inverse of cov to the inv var.
+ old_inv = sess.run(factor.get_inverse(damping_func).to_dense())
+ self.assertAllClose(
+ sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv)
+
+ sess.run(ops)
+ new_inv = sess.run(factor.get_inverse(damping_func).to_dense())
+ self.assertAllClose(new_inv, np.linalg.inv(cov))
+
+
+class FullFactorTest(test.TestCase):
+
+ def testFullFactorInit(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3), name='a/b/c')
+ factor = ff.FullFactor((tensor,), 32)
+ factor.instantiate_cov_variables()
+ self.assertEqual([6, 6], factor.get_cov().get_shape().as_list())
+
+ def testFullFactorInitFloat64(self):
+ with tf_ops.Graph().as_default():
+ dtype = dtypes.float64_ref
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
+ factor = ff.FullFactor((tensor,), 32)
+ factor.instantiate_cov_variables()
+ cov = factor.get_cov()
+ self.assertEqual(cov.dtype, dtype)
+ self.assertEqual([6, 6], cov.get_shape().as_list())
+
+ def testMakeCovarianceUpdateOp(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ tensor = array_ops.constant([1., 2.], name='a/b/c')
+ factor = ff.FullFactor((tensor,), 2)
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(factor.make_covariance_update_op(.5))
+ self.assertAllClose([[0.75, 0.5], [0.5, 1.5]], new_cov)
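+ # A rough check of the constants above (assuming the covariance variable
+ # is initialized to the identity, per init_covariances_at_zero=False at
+ # the top of this file, and that the update divides the outer product by
+ # the batch_size argument of 2):
+ #   0.5 * np.eye(2) + 0.5 * np.outer([1., 2.], [1., 2.]) / 2.
+ #   == [[0.75, 0.5], [0.5, 1.5]]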
+
+
+class NaiveDiagonalFactorTest(test.TestCase):
+
+ def testNaiveDiagonalFactorInit(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3), name='a/b/c')
+ factor = ff.NaiveDiagonalFactor((tensor,), 32)
+ factor.instantiate_cov_variables()
+ self.assertEqual([6, 1], factor.get_cov().get_shape().as_list())
+
+ def testNaiveDiagonalFactorInitFloat64(self):
+ with tf_ops.Graph().as_default():
+ dtype = dtypes.float64_ref
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
+ factor = ff.NaiveDiagonalFactor((tensor,), 32)
+ factor.instantiate_cov_variables()
+ cov = factor.get_cov()
+ self.assertEqual(cov.dtype, dtype)
+ self.assertEqual([6, 1], cov.get_shape().as_list())
+
+ def testMakeCovarianceUpdateOp(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ tensor = array_ops.constant([1., 2.], name='a/b/c')
+ factor = ff.NaiveDiagonalFactor((tensor,), 2)
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(factor.make_covariance_update_op(.5))
+ self.assertAllClose([[0.75], [1.5]], new_cov)
+
+
+class EmbeddingInputKroneckerFactorTest(test.TestCase):
+
+ def testInitialization(self):
+ with tf_ops.Graph().as_default():
+ input_ids = array_ops.constant([[0], [1], [4]])
+ vocab_size = 5
+ factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
+ factor.instantiate_cov_variables()
+ cov = factor.get_cov()
+ self.assertEqual(cov.shape.as_list(), [vocab_size])
+
+ def testCovarianceUpdateOp(self):
+ with tf_ops.Graph().as_default():
+ input_ids = array_ops.constant([[0], [1], [4]])
+ vocab_size = 5
+ factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
+ factor.instantiate_cov_variables()
+ cov_update_op = factor.make_covariance_update_op(0.0)
+
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(cov_update_op)
+ self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov)
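+ # The expected value above is an occurrence count: ids 0, 1 and 4 each
+ # appear once among the 3 input ids, so the diagonal is [1, 1, 0, 0, 1] / 3.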
+
+
+class ConvDiagonalFactorTest(test.TestCase):
+
+ def setUp(self):
+ self.batch_size = 10
+ self.height = self.width = 32
+ self.in_channels = 3
+ self.out_channels = 1
+ self.kernel_height = self.kernel_width = 3
+ self.strides = [1, 2, 2, 1]
+ self.data_format = 'NHWC'
+ self.padding = 'SAME'
+ self.kernel_shape = [
+ self.kernel_height, self.kernel_width, self.in_channels,
+ self.out_channels
+ ]
+
+ def testInit(self):
+ with tf_ops.Graph().as_default():
+ inputs = random_ops.random_uniform(
+ [self.batch_size, self.height, self.width, self.in_channels])
+ outputs_grads = [
+ random_ops.random_uniform([
+ self.batch_size, self.height // self.strides[1],
+ self.width // self.strides[2], self.out_channels
+ ]) for _ in range(3)
+ ]
+
+ factor = ff.ConvDiagonalFactor(
+ (inputs,),
+ (outputs_grads,),
+ self.kernel_shape,
+ self.strides,
+ self.padding,
+ data_format=self.data_format)
+ factor.instantiate_cov_variables()
+
+ # Ensure covariance matrix's shape makes sense.
+ self.assertEqual([
+ self.kernel_height * self.kernel_width * self.in_channels,
+ self.out_channels
+ ],
+ factor.get_cov().shape.as_list())
+
+ def testMakeCovarianceUpdateOp(self):
+ with tf_ops.Graph().as_default():
+ # Construct all arguments such that the convolution kernel is applied in
+ # exactly one spatial location.
+ inputs = np.random.randn(
+ 1, # batch_size
+ self.kernel_height,
+ self.kernel_width,
+ self.in_channels) # in_channels
+ outputs_grad = np.random.randn(
+ 1, # batch_size
+ 1, # output_height
+ 1, # output_width
+ self.out_channels)
+
+ factor = ff.ConvDiagonalFactor(
+ (constant_op.constant(inputs),),
+ ((constant_op.constant(outputs_grad),),),
+ self.kernel_shape,
+ strides=[1, 1, 1, 1],
+ padding='VALID')
+ factor.instantiate_cov_variables()
+
+ # Completely forget initial value on first update.
+ cov_update_op = factor.make_covariance_update_op(0.0)
+
+ # Ensure the new covariance equals the elementwise square of the outer
+ # product of the vectorized inputs and output gradients.
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ cov = sess.run(cov_update_op)
+ expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2
+ self.assertAllClose(expected_cov, cov)
+
+ def testHasBias(self):
+ with tf_ops.Graph().as_default():
+ inputs = random_ops.random_uniform(
+ [self.batch_size, self.height, self.width, self.in_channels])
+ outputs_grads = [
+ random_ops.random_uniform([
+ self.batch_size, self.height // self.strides[1],
+ self.width // self.strides[2], self.out_channels
+ ]) for _ in range(3)
+ ]
+
+ factor = ff.ConvDiagonalFactor(
+ (inputs,),
+ (outputs_grads,),
+ self.kernel_shape,
+ self.strides,
+ self.padding,
+ data_format=self.data_format,
+ has_bias=True)
+ factor.instantiate_cov_variables()
+
+ # Ensure shape accounts for bias.
+ self.assertEqual([
+ self.kernel_height * self.kernel_width * self.in_channels + 1,
+ self.out_channels
+ ],
+ factor.get_cov().shape.as_list())
+
+ # Ensure update op doesn't crash.
+ cov_update_op = factor.make_covariance_update_op(0.0)
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(cov_update_op)
+
+
+class FullyConnectedKroneckerFactorTest(test.TestCase):
+
+ def _testFullyConnectedKroneckerFactorInit(self,
+ has_bias,
+ final_shape,
+ dtype=dtypes.float32_ref):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
+ factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias)
+ factor.instantiate_cov_variables()
+ cov = factor.get_cov()
+ self.assertEqual(cov.dtype, dtype)
+ self.assertEqual(final_shape, cov.get_shape().as_list())
+
+ def testFullyConnectedKroneckerFactorInitNoBias(self):
+ for dtype in (dtypes.float32_ref, dtypes.float64_ref):
+ self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype)
+
+ def testFullyConnectedKroneckerFactorInitWithBias(self):
+ for dtype in (dtypes.float32_ref, dtypes.float64_ref):
+ self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype)
+
+ def testMakeCovarianceUpdateOpWithBias(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
+ factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True)
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(factor.make_covariance_update_op(.5))
+ self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov)
+
+ def testMakeCovarianceUpdateOpNoBias(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
+ factor = ff.FullyConnectedKroneckerFactor(((tensor,),))
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(factor.make_covariance_update_op(.5))
+ self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
+
+
+class ConvFactorTestCase(test.TestCase):
+
+ def assertMatrixRank(self, rank, matrix, atol=1e-5):
+ assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.'
+ eigvals = np.linalg.eigvals(matrix)
+ nnz_eigvals = np.sum(eigvals > atol)
+ self.assertEqual(
+ rank,
+ nnz_eigvals,
+ msg=('Found %d of %d expected non-zero eigenvalues: %s.' %
+ (nnz_eigvals, rank, eigvals)))
+
+
+class ConvInputKroneckerFactorTest(ConvFactorTestCase):
+
+ def test3DConvolution(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 3**3
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=(random_ops.random_uniform(
+ (batch_size, width, width, width, in_channels), seed=0),),
+ filter_shape=(width, width, width, in_channels, out_channels),
+ padding='SAME',
+ strides=(2, 2, 2),
+ extract_patches_fn='extract_convolution_patches',
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ # Ensure shape of covariance matches input size of filter.
+ input_size = in_channels * (width**3)
+ self.assertEqual([input_size, input_size],
+ factor.get_cov().shape.as_list())
+
+ # Ensure cov_update_op doesn't crash.
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov())
+
+ # Cov should be rank-8, as the stride-2 filter is applied only at the 8
+ # corners of the 3-D input volume.
+ self.assertMatrixRank(8, cov)
+
+ def testPointwiseConv2d(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 3**2
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=(random_ops.random_uniform(
+ (batch_size, width, width, in_channels), seed=0),),
+ filter_shape=(1, 1, in_channels, out_channels),
+ padding='SAME',
+ strides=(1, 1, 1, 1),
+ extract_patches_fn='extract_pointwise_conv2d_patches',
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ # Ensure shape of covariance matches input size of filter.
+ self.assertEqual([in_channels, in_channels],
+ factor.get_cov().shape.as_list())
+
+ # Ensure cov_update_op doesn't crash.
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov())
+
+ # Cov should be rank-9, as the filter will be applied at each location.
+ self.assertMatrixRank(9, cov)
+
+ def testStrides(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 3**2
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=(random_ops.random_uniform(
+ (batch_size, width, width, in_channels), seed=0),),
+ filter_shape=(1, 1, in_channels, out_channels),
+ padding='SAME',
+ strides=(1, 2, 1, 1),
+ extract_patches_fn='extract_image_patches',
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov())
+
+ # Cov should be the sum of 3 * 2 = 6 outer products.
+ self.assertMatrixRank(6, cov)
+
+ def testDilationRate(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ in_channels = 2
+ out_channels = 4
+
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=(random_ops.random_uniform(
+ (batch_size, width, width, in_channels), seed=0),),
+ filter_shape=(3, 3, in_channels, out_channels),
+ padding='SAME',
+ extract_patches_fn='extract_image_patches',
+ strides=(1, 1, 1, 1),
+ dilation_rate=(1, width, width, 1),
+ has_bias=False)
+ factor.instantiate_cov_variables()
+
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov())
+
+ # Cov should be rank = in_channels, as only the center of the filter
+ # receives non-zero input for each input channel.
+ self.assertMatrixRank(in_channels, cov)
+
+ def testConvInputKroneckerFactorInitNoBias(self):
+ with tf_ops.Graph().as_default():
+ tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
+ factor = ff.ConvInputKroneckerFactor(
+ inputs=(tensor,),
+ filter_shape=(1, 2, 3, 4),
+ padding='SAME',
+ has_bias=False)
+ factor.instantiate_cov_variables()
+ self.assertEqual([1 * 2 * 3, 1 * 2 * 3],
+ factor.get_cov().get_shape().as_list())
+
+ def testConvInputKroneckerFactorInit(self):
+ with tf_ops.Graph().as_default():
+ tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
+ factor = ff.ConvInputKroneckerFactor(
+ (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
+ factor.instantiate_cov_variables()
+ self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
+ factor.get_cov().get_shape().as_list())
+
+ def testConvInputKroneckerFactorInitFloat64(self):
+ with tf_ops.Graph().as_default():
+ dtype = dtypes.float64_ref
+ tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64)
+ factor = ff.ConvInputKroneckerFactor(
+ (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
+ factor.instantiate_cov_variables()
+ cov = factor.get_cov()
+ self.assertEqual(cov.dtype, dtype)
+ self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
+ cov.get_shape().as_list())
+
+ def testMakeCovarianceUpdateOpWithBias(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ input_shape = (2, 1, 1, 1)
+ tensor = array_ops.constant(
+ np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
+ np.float32))
+ factor = ff.ConvInputKroneckerFactor(
+ (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True)
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(factor.make_covariance_update_op(0.))
+ self.assertAllClose(
+ [
+ [(1. + 4.) / 2., (1. + 2.) / 2.], #
+ [(1. + 2.) / 2., (1. + 1.) / 2.]
+ ], #
+ new_cov)
+
+ def testMakeCovarianceUpdateOpNoBias(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ input_shape = (2, 1, 1, 1)
+ tensor = array_ops.constant(
+ np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
+ np.float32))
+ factor = ff.ConvInputKroneckerFactor(
+ (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME')
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(factor.make_covariance_update_op(0.))
+ self.assertAllClose([[(1. + 4.) / 2.]], new_cov)
+
+ def testSubSample(self):
+ with tf_ops.Graph().as_default():
+ patches_1 = array_ops.constant(1, shape=(10, 2))
+ patches_2 = array_ops.constant(1, shape=(10, 8))
+ patches_3 = array_ops.constant(1, shape=(3, 3))
+ patches_1_sub = ff._subsample_for_cov_computation(patches_1)
+ patches_2_sub = ff._subsample_for_cov_computation(patches_2)
+ patches_3_sub = ff._subsample_for_cov_computation(patches_3)
+ patches_1_sub_batch_size = patches_1_sub.shape.as_list()[0]
+ patches_2_sub_batch_size = patches_2_sub.shape.as_list()[0]
+ patches_3_sub_batch_size = patches_3_sub.shape.as_list()[0]
+ self.assertEqual(2, patches_1_sub_batch_size)
+ self.assertEqual(8, patches_2_sub_batch_size)
+ self.assertEqual(3, patches_3_sub_batch_size)
+
+
+class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
+
+ def test3DConvolution(self):
+ with tf_ops.Graph().as_default():
+ batch_size = 1
+ width = 3
+ out_channels = width**3
+
+ factor = ff.ConvOutputKroneckerFactor(outputs_grads=([
+ random_ops.random_uniform(
+ (batch_size, width, width, width, out_channels), seed=0)
+ ],))
+ factor.instantiate_cov_variables()
+
+ with self.test_session() as sess:
+ sess.run(tf_variables.global_variables_initializer())
+ sess.run(factor.make_covariance_update_op(0.0))
+ cov = sess.run(factor.get_cov())
+
+ # Cov should be rank 3^3, as each spatial position contributes a rank-1
+ # update.
+ self.assertMatrixRank(width**3, cov)
+
+ def testConvOutputKroneckerFactorInit(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c')
+ factor = ff.ConvOutputKroneckerFactor(((tensor,),))
+ factor.instantiate_cov_variables()
+ self.assertEqual([5, 5], factor.get_cov().get_shape().as_list())
+
+ def testConvOutputKroneckerFactorInitFloat64(self):
+ with tf_ops.Graph().as_default():
+ dtype = dtypes.float64_ref
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c')
+ factor = ff.ConvOutputKroneckerFactor(((tensor,),))
+ factor.instantiate_cov_variables()
+ cov = factor.get_cov()
+ self.assertEqual(cov.dtype, dtype)
+ self.assertEqual([5, 5], cov.get_shape().as_list())
+
+ def testMakeCovarianceUpdateOp(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32)
+ factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),))
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(factor.make_covariance_update_op(.5))
+ self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov)
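+ # A rough check of the constants above (assuming an identity-initialized
+ # covariance and a normalizer of batch_size * num_spatial_locations = 8):
+ # reshaping the input to rows [[1, 2], [3, 4], ..., [15, 16]] gives
+ #   0.5 * np.eye(2) + 0.5 * x.T.dot(x) / 8.
+ #   == [[43., 46.5], [46.5, 51.5]]
+ # where x = np.arange(1, 17, dtype=np.float32).reshape(8, 2).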
+
+
+class FullyConnectedMultiKFTest(test.TestCase):
+
+ def testFullyConnectedMultiKFInit(self):
+ with tf_ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3), name='a/b/c')
+ factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
+ factor.instantiate_cov_variables()
+ self.assertEqual([3, 3], factor.get_cov().get_shape().as_list())
+
+ def testFullyConnectedMultiKFInitFloat64(self):
+ with tf_ops.Graph().as_default():
+ dtype = dtypes.float64_ref
+ random_seed.set_random_seed(200)
+ tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
+ factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
+ factor.instantiate_cov_variables()
+ cov = factor.get_cov()
+ self.assertEqual(cov.dtype, dtype)
+ self.assertEqual([3, 3], cov.get_shape().as_list())
+
+ def testMakeCovarianceUpdateOpWithBias(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
+ factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True)
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(factor.make_covariance_update_op(.5))
+ self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov)
+
+ def testMakeCovarianceUpdateOpNoBias(self):
+ with tf_ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
+ factor = ff.FullyConnectedMultiKF(((tensor,),))
+ factor.instantiate_cov_variables()
+
+ sess.run(tf_variables.global_variables_initializer())
+ new_cov = sess.run(factor.make_covariance_update_op(.5))
+ self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
new file mode 100644
index 0000000000..cb80fca370
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
@@ -0,0 +1,597 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.layer_collection."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kfac.python.ops import fisher_blocks
+from tensorflow.contrib.kfac.python.ops import fisher_factors
+from tensorflow.contrib.kfac.python.ops import layer_collection
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+class MockFisherBlock(object):
+ """A fake FisherBlock."""
+
+ num_registered_towers = 2
+
+ def __init__(self, name='MockFisherBlock'):
+ self.name = name
+
+ def __eq__(self, other):
+ return isinstance(other, MockFisherBlock) and other.name == self.name
+
+ def __hash__(self):
+ return hash(self.name)
+
+
+class LayerParametersDictTest(test.TestCase):
+
+ def testSetItem(self):
+ """Ensure insertion, contains, retrieval works for supported key types."""
+ with ops.Graph().as_default():
+ lp_dict = layer_collection.LayerParametersDict()
+
+ x = array_ops.constant(0)
+ y0 = array_ops.constant(0)
+ y1 = array_ops.constant(0)
+ z0 = array_ops.constant(0)
+ z1 = array_ops.constant(0)
+ keys = [x, (y0, y1), [z0, z1]]
+ for key in keys:
+ lp_dict[key] = key
+
+ for key in keys:
+ self.assertTrue(key in lp_dict)
+ self.assertEqual(lp_dict[key], key)
+
+ def testSetItemOverlap(self):
+ """Ensure insertion fails if key overlaps with existing key."""
+ with ops.Graph().as_default():
+ lp_dict = layer_collection.LayerParametersDict()
+
+ x = array_ops.constant(0)
+ y = array_ops.constant(0)
+ lp_dict[x] = 'value'
+
+ with self.assertRaises(ValueError):
+ lp_dict[(x, y)] = 'value'
+
+ # Ensure 'y' wasn't inserted.
+ self.assertTrue(x in lp_dict)
+ self.assertFalse(y in lp_dict)
+
+
+class LayerCollectionTest(test.TestCase):
+
+ def testLayerCollectionInit(self):
+ lc = layer_collection.LayerCollection()
+ self.assertEqual(0, len(lc.get_blocks()))
+ self.assertEqual(0, len(lc.get_factors()))
+ self.assertFalse(lc.losses)
+
+ def testRegisterBlocks(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ lc = layer_collection.LayerCollection()
+ lc.register_fully_connected(
+ array_ops.constant(1), array_ops.constant(2), array_ops.constant(3))
+ lc.register_fully_connected(
+ array_ops.constant(1),
+ array_ops.constant(2),
+ array_ops.constant(3),
+ approx=layer_collection.APPROX_DIAGONAL_NAME)
+ lc.register_conv2d(
+ params=array_ops.ones((2, 3, 4, 5)),
+ strides=[1, 1, 1, 1],
+ padding='SAME',
+ inputs=array_ops.ones((1, 2, 3, 4)),
+ outputs=array_ops.ones((1, 1, 1, 5)))
+ lc.register_conv2d(
+ params=array_ops.ones((2, 3, 4, 5)),
+ strides=[1, 1, 1, 1],
+ padding='SAME',
+ inputs=array_ops.ones((1, 2, 3, 4)),
+ outputs=array_ops.ones((1, 1, 1, 5)),
+ approx=layer_collection.APPROX_DIAGONAL_NAME)
+ lc.register_separable_conv2d(
+ depthwise_params=array_ops.ones((3, 3, 1, 2)),
+ pointwise_params=array_ops.ones((1, 1, 2, 4)),
+ inputs=array_ops.ones((32, 5, 5, 1)),
+ depthwise_outputs=array_ops.ones((32, 5, 5, 2)),
+ pointwise_outputs=array_ops.ones((32, 5, 5, 4)),
+ strides=[1, 1, 1, 1],
+ padding='SAME')
+ lc.register_convolution(
+ params=array_ops.ones((3, 3, 1, 8)),
+ inputs=array_ops.ones((32, 5, 5, 1)),
+ outputs=array_ops.ones((32, 5, 5, 8)),
+ padding='SAME')
+ lc.register_generic(
+ array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
+ lc.register_generic(
+ array_ops.constant(6),
+ 16,
+ approx=layer_collection.APPROX_DIAGONAL_NAME)
+ lc.register_fully_connected_multi(
+ array_ops.constant(1),
+ (array_ops.constant(2), array_ops.constant(3)),
+ (array_ops.constant(4), array_ops.constant(5)))
+ lc.register_conv2d_multi(
+ params=array_ops.ones((2, 3, 4, 5)),
+ strides=[1, 1, 1, 1],
+ padding='SAME',
+ inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))),
+ outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10))))
+ lc.register_embedding_multi(
+ array_ops.constant((1,)),
+ (array_ops.constant(2), array_ops.constant(3)),
+ (array_ops.constant(4), array_ops.constant(5)))
+
+ self.assertEqual(12, len(lc.get_blocks()))
+
+ def testRegisterBlocksMultipleRegistrations(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ lc = layer_collection.LayerCollection()
+ key = array_ops.constant(1)
+ lc.register_fully_connected(key, array_ops.constant(2),
+ array_ops.constant(3))
+ with self.assertRaises(ValueError) as cm:
+ lc.register_generic(key, 16)
+ self.assertIn('already in LayerCollection', str(cm.exception))
+
+ def testRegisterSingleParamNotRegistered(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {
+ variable_scope.get_variable('y', initializer=array_ops.constant(1,)):
+ '1'
+ }
+ lc.register_block(x, 'foo')
+
+ def testShouldRegisterSingleParamRegistered(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {x: '1'}
+ with self.assertRaises(ValueError) as cm:
+ lc.register_block(x, 'foo')
+ self.assertIn('already in LayerCollection', str(cm.exception))
+
+ def testRegisterSingleParamRegisteredInTuple(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {(x, y): '1'}
+ with self.assertRaises(ValueError) as cm:
+ lc.register_block(x, 'foo')
+ self.assertIn('was already registered', str(cm.exception))
+
+ def testRegisterTupleParamNotRegistered(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {
+ variable_scope.get_variable('z', initializer=array_ops.constant(1,)):
+ '1'
+ }
+
+ lc.register_block((x, y), 'foo')
+ self.assertEqual(set(['1', 'foo']), set(lc.get_blocks()))
+
+ def testRegisterTupleParamRegistered(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {(x, y): '1'}
+
+ with self.assertRaises(ValueError) as cm:
+ lc.register_block((x, y), 'foo')
+ self.assertIn('already in LayerCollection', str(cm.exception))
+
+ def testRegisterTupleParamRegisteredInSuperset(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
+ z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {(x, y, z): '1'}
+
+ with self.assertRaises(ValueError) as cm:
+ lc.register_block((x, y), 'foo')
+ self.assertIn('was already registered', str(cm.exception))
+
+ def testRegisterTupleParamSomeRegistered(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
+ z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')}
+
+ with self.assertRaises(ValueError) as cm:
+ lc.register_block((x, y), MockFisherBlock('foo'))
+ self.assertIn('was already registered', str(cm.exception))
+
+ def testRegisterTupleVarSomeRegisteredInOtherTuples(self):
+ x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
+ y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
+ z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
+ w = variable_scope.get_variable('w', initializer=array_ops.constant(1,))
+ lc = layer_collection.LayerCollection()
+ lc.fisher_blocks = {(x, z): '1', (z, w): '2'}
+
+ with self.assertRaises(ValueError) as cm:
+ lc.register_block((x, y), 'foo')
+ self.assertIn('was already registered', str(cm.exception))
+
+ def testRegisterCategoricalPredictiveDistribution(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ logits = linalg_ops.eye(2)
+
+ lc = layer_collection.LayerCollection()
+ lc.register_categorical_predictive_distribution(logits, seed=200)
+ single_loss = sess.run(lc.total_sampled_loss())
+
+ lc2 = layer_collection.LayerCollection()
+ lc2.register_categorical_predictive_distribution(logits, seed=200)
+ lc2.register_categorical_predictive_distribution(logits, seed=200)
+ double_loss = sess.run(lc2.total_sampled_loss())
+ self.assertAlmostEqual(2 * single_loss, double_loss)
+
+ def testLossFunctionByName(self):
+ """Ensure loss functions can be identified by name."""
+ with ops.Graph().as_default():
+ logits = linalg_ops.eye(2)
+ lc = layer_collection.LayerCollection()
+
+ # Create a new loss function by name.
+ lc.register_categorical_predictive_distribution(logits, name='loss1')
+ self.assertEqual(1, len(lc.towers_by_loss))
+
+ # Add logits to same loss function.
+ lc.register_categorical_predictive_distribution(
+ logits, name='loss1', reuse=True)
+ self.assertEqual(1, len(lc.towers_by_loss))
+
+ # Add another new loss function.
+ lc.register_categorical_predictive_distribution(logits, name='loss2')
+ self.assertEqual(2, len(lc.towers_by_loss))
+
+ def testLossFunctionWithoutName(self):
+ """Ensure loss functions get unique names if 'name' not specified."""
+ with ops.Graph().as_default():
+ logits = linalg_ops.eye(2)
+ lc = layer_collection.LayerCollection()
+
+ # Create a new loss function with default names.
+ lc.register_categorical_predictive_distribution(logits)
+ lc.register_categorical_predictive_distribution(logits)
+ self.assertEqual(2, len(lc.losses))
+
+ def testCategoricalPredictiveDistributionMultipleMinibatches(self):
+ """Ensure multiple minibatches are registered."""
+ with ops.Graph().as_default():
+ batch_size = 3
+ output_size = 2
+ logits = array_ops.zeros([batch_size, output_size])
+ targets = array_ops.ones([batch_size], dtype=dtypes.int32)
+ lc = layer_collection.LayerCollection()
+
+ # Create a new loss function.
+ lc.register_categorical_predictive_distribution(
+ logits, targets=targets, name='loss1')
+
+ # Can add when reuse=True
+ lc.register_categorical_predictive_distribution(
+ logits, targets=targets, name='loss1', reuse=True)
+
+ # Can add when reuse=VARIABLE_SCOPE and reuse=True there.
+ with variable_scope.variable_scope(
+ variable_scope.get_variable_scope(), reuse=True):
+ lc.register_categorical_predictive_distribution(
+ logits,
+ targets=targets,
+ name='loss1',
+ reuse=layer_collection.VARIABLE_SCOPE)
+
+ # Can't add when reuse=False
+ with self.assertRaises(KeyError):
+ lc.register_categorical_predictive_distribution(
+ logits, targets=targets, name='loss1', reuse=False)
+
+ # Can't add when reuse=VARIABLE_SCOPE and reuse=False there.
+ with self.assertRaises(KeyError):
+ lc.register_categorical_predictive_distribution(
+ logits,
+ targets=targets,
+ name='loss1',
+ reuse=layer_collection.VARIABLE_SCOPE)
+
+ self.assertEqual(len(lc.towers_by_loss), 1)
+ # Three successful registrations.
+ self.assertEqual(len(lc.towers_by_loss[0]), 3)
+
+ def testRegisterCategoricalPredictiveDistributionBatchSize1(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ logits = random_ops.random_normal((1, 2))
+ lc = layer_collection.LayerCollection()
+
+ lc.register_categorical_predictive_distribution(logits, seed=200)
+
+ def testRegisterCategoricalPredictiveDistributionSpecifiedTargets(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ logits = array_ops.constant([[1., 2.], [3., 4.]], dtype=dtypes.float32)
+ lc = layer_collection.LayerCollection()
+ targets = array_ops.constant([0, 1], dtype=dtypes.int32)
+
+ lc.register_categorical_predictive_distribution(logits, targets=targets)
+ single_loss = sess.run(lc.total_loss())
+ self.assertAlmostEqual(1.6265233, single_loss)
+
+ def testRegisterNormalPredictiveDistribution(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ predictions = array_ops.constant(
+ [[1., 2.], [3., 4]], dtype=dtypes.float32)
+
+ lc = layer_collection.LayerCollection()
+ lc.register_normal_predictive_distribution(predictions, 1., seed=200)
+ single_loss = sess.run(lc.total_sampled_loss())
+
+ lc2 = layer_collection.LayerCollection()
+ lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
+ lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
+ double_loss = sess.run(lc2.total_sampled_loss())
+
+ self.assertAlmostEqual(2 * single_loss, double_loss)
+
+ def testRegisterNormalPredictiveDistributionSpecifiedTargets(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ predictions = array_ops.constant(
+ [[1., 2.], [3., 4.]], dtype=dtypes.float32)
+ lc = layer_collection.LayerCollection()
+ targets = array_ops.constant([[3., 1.], [4., 2.]], dtype=dtypes.float32)
+
+ lc.register_normal_predictive_distribution(
+ predictions, 2.**2, targets=targets)
+ single_loss = sess.run(lc.total_loss())
+ self.assertAlmostEqual(7.6983433, single_loss)
+
+ def ensureLayerReuseWorks(self, register_fn):
+ """Ensure the 'reuse' keyword argument function as intended.
+
+ Args:
+ register_fn: function for registering a layer. Arguments are
+ layer_collection, reuse, and approx.
+ """
+ # Fails on second if reuse=False.
+ lc = layer_collection.LayerCollection()
+ register_fn(lc)
+ with self.assertRaises(ValueError):
+ register_fn(lc, reuse=False)
+
+ # Succeeds on second if reuse=True.
+ lc = layer_collection.LayerCollection()
+ register_fn(lc)
+ register_fn(lc, reuse=True)
+
+ # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse.
+ lc = layer_collection.LayerCollection()
+ register_fn(lc)
+ with self.assertRaises(ValueError):
+ register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
+
+ # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse.
+ lc = layer_collection.LayerCollection()
+ register_fn(lc)
+ with variable_scope.variable_scope(
+ variable_scope.get_variable_scope(), reuse=True):
+ register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
+
+ # Fails if block type changes.
+ lc = layer_collection.LayerCollection()
+ register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME)
+ with self.assertRaises(ValueError):
+ register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True)
+
+ # Fails if reuse requested but no FisherBlock exists.
+ lc = layer_collection.LayerCollection()
+ with self.assertRaises(KeyError):
+ register_fn(lc, reuse=True)
+
+ def testRegisterFullyConnectedReuse(self):
+ """Ensure the 'reuse' works with register_fully_connected."""
+ with ops.Graph().as_default():
+ inputs = array_ops.ones([2, 10])
+ outputs = array_ops.zeros([2, 5])
+ params = (
+ variable_scope.get_variable('w', [10, 5]), #
+ variable_scope.get_variable('b', [5]))
+
+ def register_fn(lc, **kwargs):
+ lc.register_fully_connected(
+ params=params, inputs=inputs, outputs=outputs, **kwargs)
+
+ self.ensureLayerReuseWorks(register_fn)
+
+ def testRegisterConv2dReuse(self):
+ """Ensure the 'reuse' works with register_conv2d."""
+ with ops.Graph().as_default():
+ inputs = array_ops.ones([2, 5, 5, 10])
+ outputs = array_ops.zeros([2, 5, 5, 3])
+ params = (
+ variable_scope.get_variable('w', [1, 1, 10, 3]), #
+ variable_scope.get_variable('b', [3]))
+
+ def register_fn(lc, **kwargs):
+ lc.register_conv2d(
+ params=params,
+ strides=[1, 1, 1, 1],
+ padding='SAME',
+ inputs=inputs,
+ outputs=outputs,
+ **kwargs)
+
+ self.ensureLayerReuseWorks(register_fn)
+
+ def testReuseWithInvalidRegistration(self):
+ """Invalid registrations shouldn't overwrite existing blocks."""
+ with ops.Graph().as_default():
+ inputs = array_ops.ones([2, 5, 5, 10])
+ outputs = array_ops.zeros([2, 5, 5, 3])
+ w = variable_scope.get_variable('w', [1, 1, 10, 3])
+ b = variable_scope.get_variable('b', [3])
+ lc = layer_collection.LayerCollection()
+ lc.register_fully_connected(w, inputs, outputs)
+ self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
+ with self.assertRaises(KeyError):
+ lc.register_fully_connected((w, b), inputs, outputs, reuse=True)
+ self.assertNotIn((w, b), lc.fisher_blocks)
+ self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
+ lc.register_fully_connected(w, inputs, outputs, reuse=True)
+ self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2)
+
+ def testMakeOrGetFactor(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ lc = layer_collection.LayerCollection()
+ key = array_ops.constant(1)
+ lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
+ lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
+ lc.make_or_get_factor(fisher_factors.FullFactor,
+ ((array_ops.constant(2),), 16))
+
+ self.assertEqual(2, len(lc.get_factors()))
+ variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertTrue(
+ all([var.name.startswith('LayerCollection') for var in variables]))
+
+ def testMakeOrGetFactorCustomScope(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ scope = 'Foo'
+ lc = layer_collection.LayerCollection(name=scope)
+ key = array_ops.constant(1)
+ lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
+ lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
+ lc.make_or_get_factor(fisher_factors.FullFactor,
+ ((array_ops.constant(2),), 16))
+
+ self.assertEqual(2, len(lc.get_factors()))
+ variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertTrue(all([var.name.startswith(scope) for var in variables]))
+
+ def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self):
+ x = variable_scope.get_variable('x', shape=())
+ y = variable_scope.get_variable('y', shape=())
+ z = variable_scope.get_variable('z', shape=())
+ lc = layer_collection.LayerCollection()
+ lc.define_linked_parameters((x, y))
+
+ with self.assertRaises(ValueError):
+ lc.define_linked_parameters((x, z))
+
+ def testIdentifySubsetPreviouslyRegisteredTensor(self):
+ x = variable_scope.get_variable('x', shape=())
+ y = variable_scope.get_variable('y', shape=())
+ lc = layer_collection.LayerCollection()
+ lc.define_linked_parameters((x, y))
+
+ with self.assertRaises(ValueError):
+ lc.define_linked_parameters(x)
+
+ def testSpecifyApproximation(self):
+ w_0 = variable_scope.get_variable('w_0', [10, 10])
+ w_1 = variable_scope.get_variable('w_1', [10, 10])
+
+ b_0 = variable_scope.get_variable('b_0', [10])
+ b_1 = variable_scope.get_variable('b_1', [10])
+
+ x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
+ x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
+
+ pre_bias_0 = math_ops.matmul(x_0, w_0)
+ pre_bias_1 = math_ops.matmul(x_1, w_1)
+
+ # Build the fully connected layers in the graph.
+ pre_bias_0 + b_0 # pylint: disable=pointless-statement
+ pre_bias_1 + b_1 # pylint: disable=pointless-statement
+
+ lc = layer_collection.LayerCollection()
+ lc.define_linked_parameters(
+ w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME)
+ lc.define_linked_parameters(
+ w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME)
+ lc.define_linked_parameters(
+ b_0, approximation=layer_collection.APPROX_FULL_NAME)
+ lc.define_linked_parameters(
+ b_1, approximation=layer_collection.APPROX_FULL_NAME)
+
+ lc.register_fully_connected(w_0, x_0, pre_bias_0)
+ lc.register_fully_connected(
+ w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME)
+ self.assertIsInstance(lc.fisher_blocks[w_0],
+ fisher_blocks.FullyConnectedDiagonalFB)
+ self.assertIsInstance(lc.fisher_blocks[w_1],
+ fisher_blocks.FullyConnectedKFACBasicFB)
+
+ lc.register_generic(b_0, batch_size=1)
+ lc.register_generic(
+ b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME)
+ self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB)
+ self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB)
+
+ def testDefaultLayerCollection(self):
+ with ops.Graph().as_default():
+ # Can't get default if there isn't one set.
+ with self.assertRaises(ValueError):
+ layer_collection.get_default_layer_collection()
+
+ # Can't set default twice.
+ lc = layer_collection.LayerCollection()
+ layer_collection.set_default_layer_collection(lc)
+ with self.assertRaises(ValueError):
+ layer_collection.set_default_layer_collection(lc)
+
+ # Same as one set.
+ self.assertTrue(lc is layer_collection.get_default_layer_collection())
+
+ # Can set to None.
+ layer_collection.set_default_layer_collection(None)
+ with self.assertRaises(ValueError):
+ layer_collection.get_default_layer_collection()
+
+ # as_default() is the same as setting/clearing.
+ with lc.as_default():
+ self.assertTrue(lc is layer_collection.get_default_layer_collection())
+ with self.assertRaises(ValueError):
+ layer_collection.get_default_layer_collection()
+
+
+if __name__ == '__main__':
+ test.main()
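For orientation, the snippet below is a minimal sketch of the registration workflow these tests exercise; the variable names and shapes are illustrative assumptions, not part of the change itself.

    import tensorflow as tf
    from tensorflow.contrib.kfac.python.ops import layer_collection

    # Hypothetical single fully connected layer feeding a softmax loss.
    inputs = tf.ones([2, 10])
    weights = tf.get_variable('w', [10, 5])
    bias = tf.get_variable('b', [5])
    logits = tf.matmul(inputs, weights) + bias

    lc = layer_collection.LayerCollection()
    # Register the layer's parameters together with its inputs and outputs ...
    lc.register_fully_connected((weights, bias), inputs, logits)
    # ... and register the model's predictive distribution over the logits.
    lc.register_categorical_predictive_distribution(logits)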
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
new file mode 100644
index 0000000000..c00af5593f
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
@@ -0,0 +1,190 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.loss_functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.kfac.python.ops import loss_functions
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class InsertSliceInZerosTest(test.TestCase):
+
+ def testBadShape(self):
+ bad_shaped_ones = array_ops.ones(shape=[1, 3]) # n.b. shape[1] != 1
+ with self.assertRaises(ValueError):
+ loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17)
+
+ def test3d(self):
+ input_tensor = constant_op.constant([[[1, 2]], [[3, 4]]])
+ expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]]
+ op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0)
+ with self.test_session() as sess:
+ actual_output_array = sess.run(op)
+ self.assertAllEqual(expected_output_array, actual_output_array)
+
+
+class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
+
+ def testSample(self):
+ """Ensure samples can be drawn."""
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.asarray([
+ [0., 0., 0.], #
+ [1., -1., 0.]
+ ]).astype(np.float32)
+ loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
+ array_ops.constant(logits))
+ sample = loss.sample(42)
+ sample = sess.run(sample)
+ self.assertEqual(sample.shape, (2,))
+
+ def testEvaluateOnTargets(self):
+ """Ensure log probability can be evaluated correctly."""
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.asarray([
+ [0., 0., 0.], #
+ [1., -1., 0.]
+ ]).astype(np.float32)
+ targets = np.asarray([2, 1]).astype(np.int32)
+ loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
+ array_ops.constant(logits), targets=array_ops.constant(targets))
+ neg_log_prob = loss.evaluate()
+ neg_log_prob = sess.run(neg_log_prob)
+
+ # Calculate explicit log probability of targets.
+ probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
+ log_probs = np.log([
+ probs[0, targets[0]], #
+ probs[1, targets[1]]
+ ])
+ expected_log_prob = np.sum(log_probs)
+
+ self.assertAllClose(neg_log_prob, -expected_log_prob)
+
+ def testEvaluateOnSample(self):
+ """Ensure log probability of a sample can be drawn."""
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.asarray([
+ [0., 0., 0.], #
+ [1., -1., 0.]
+ ]).astype(np.float32)
+ loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
+ array_ops.constant(logits))
+ neg_log_prob = loss.evaluate_on_sample(42)
+
+ # Simply ensure this doesn't crash. As the output is random, it's
+ # difficult to say if the output is correct or not...
+ neg_log_prob = sess.run(neg_log_prob)
+
+ def testMultiplyFisherSingleVector(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.array([1., 2., 3.])
+ loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
+
+      # The LossFunction.multiply_fisher docstring only says it supports the
+      # case where the vector is the same shape as the input natural parameters
+      # (i.e. the logits here), but we also test vectors with leading dimensions.
+ vector = np.array([1., 2., 3.])
+ vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)]
+
+ probs = np.exp(logits - np.logaddexp.reduce(logits))
+ fisher = np.diag(probs) - np.outer(probs, probs)
+
+ for vector in vectors:
+ result = loss.multiply_fisher(vector)
+ expected_result = np.dot(vector, fisher)
+ self.assertAllClose(expected_result, sess.run(result))
+
+ def testMultiplyFisherBatch(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.array([[1., 2., 3.], [4., 6., 8.]])
+ loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
+
+ vector = np.array([[1., 2., 3.], [5., 3., 1.]])
+
+ na = np.newaxis
+ probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1,
+ keepdims=True))
+ fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :]
+
+ result = loss.multiply_fisher(vector)
+ expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :]
+ self.assertEqual(sess.run(result).shape, logits.shape)
+ self.assertAllClose(expected_result, sess.run(result))
+
+
+class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase):
+
+ def testSample(self):
+ """Ensure samples can be drawn."""
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.asarray([
+ [0., 0., 0.], #
+ [1., -1., 0.]
+ ]).astype(np.float32)
+ loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
+ array_ops.constant(logits))
+ sample = loss.sample(42)
+ sample = sess.run(sample)
+ self.assertEqual(sample.shape, (2, 3))
+
+ def testEvaluateOnTargets(self):
+ """Ensure log probability can be evaluated correctly."""
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.asarray([
+ [0., 0., 0.], #
+ [1., -1., 0.]
+ ]).astype(np.float32)
+ targets = np.asarray([2, 1]).astype(np.int32)
+ loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
+ array_ops.constant(logits), targets=array_ops.one_hot(targets, 3))
+ neg_log_prob = loss.evaluate()
+ neg_log_prob = sess.run(neg_log_prob)
+
+ # Calculate explicit log probability of targets.
+ probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
+ log_probs = np.log([
+ probs[0, targets[0]], #
+ probs[1, targets[1]]
+ ])
+ expected_log_prob = np.sum(log_probs)
+
+ self.assertAllClose(neg_log_prob, -expected_log_prob)
+
+ def testEvaluateOnSample(self):
+ """Ensure log probability of a sample can be drawn."""
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.asarray([
+ [0., 0., 0.], #
+ [1., -1., 0.]
+ ]).astype(np.float32)
+ loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
+ array_ops.constant(logits))
+ neg_log_prob = loss.evaluate_on_sample(42)
+
+ # Simply ensure this doesn't crash. As the output is random, it's
+ # difficult to say if the output is correct or not...
+ neg_log_prob = sess.run(neg_log_prob)
+
+if __name__ == "__main__":
+ test.main()
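As a side note, the Fisher-vector reference used by testMultiplyFisherSingleVector can be reproduced directly in NumPy; the snippet below is a sketch of that reference computation, not part of the diff itself.

    import numpy as np

    logits = np.array([1., 2., 3.])
    probs = np.exp(logits - np.logaddexp.reduce(logits))   # softmax probabilities
    fisher = np.diag(probs) - np.outer(probs, probs)        # Fisher w.r.t. the logits
    vector = np.array([1., 2., 3.])
    expected = np.dot(vector, fisher)   # what multiply_fisher(vector) should match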
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py b/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py
new file mode 100644
index 0000000000..b20a70e4ca
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py
@@ -0,0 +1,50 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.op_queue."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kfac.python.ops import op_queue
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class OpQueueTest(test.TestCase):
+
+ def testNextOp(self):
+ """Ensures all ops get selected eventually."""
+ with tf_ops.Graph().as_default():
+ ops = [
+ math_ops.add(1, 2),
+ math_ops.subtract(1, 2),
+ math_ops.reduce_mean([1, 2]),
+ ]
+ queue = op_queue.OpQueue(ops, seed=0)
+
+ with self.test_session() as sess:
+ # Ensure every inv update op gets selected.
+ selected_ops = set([queue.next_op(sess) for _ in ops])
+ self.assertEqual(set(ops), set(selected_ops))
+
+ # Ensure additional calls don't create any new ops.
+ selected_ops.add(queue.next_op(sess))
+ self.assertEqual(set(ops), set(selected_ops))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py
new file mode 100644
index 0000000000..560a9b0b42
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py
@@ -0,0 +1,219 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
+from tensorflow.contrib.kfac.python.ops import layer_collection as lc
+from tensorflow.contrib.kfac.python.ops import optimizer
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.platform import test
+
+
+# We need to set these constants since the numerical values used in the tests
+# were chosen when these used to be the defaults.
+ff.set_global_constants(init_covariances_at_zero=False,
+ zero_debias=False,
+ init_inverses_at_zero=False)
+
+
+def dummy_layer_collection():
+ lcoll = lc.LayerCollection()
+ dummy = array_ops.constant([1., 2.])
+ lcoll.register_categorical_predictive_distribution(logits=dummy)
+ return lcoll
+
+
+class OptimizerTest(test.TestCase):
+
+ def testOptimizerInitInvalidMomentumRegistration(self):
+ with self.assertRaises(ValueError):
+ optimizer.KfacOptimizer(
+ 0.1, 0.2, 0.3, lc.LayerCollection(), momentum_type='foo')
+
+ def testOptimizerInit(self):
+ with ops.Graph().as_default():
+ layer_collection = lc.LayerCollection()
+
+ inputs = array_ops.ones((2, 1)) * 2
+ weights_val = np.ones((1, 1), dtype=np.float32) * 3.
+ weights = variable_scope.get_variable(
+ 'w', initializer=array_ops.constant(weights_val))
+ bias = variable_scope.get_variable(
+ 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
+ output = math_ops.matmul(inputs, weights) + bias
+
+ layer_collection.register_fully_connected((weights, bias), inputs, output)
+
+ logits = math_ops.tanh(output)
+ targets = array_ops.constant([[0.], [1.]])
+ output = math_ops.reduce_mean(
+ nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))
+
+ layer_collection.register_categorical_predictive_distribution(logits)
+
+ optimizer.KfacOptimizer(
+ 0.1,
+ 0.2,
+ 0.3,
+ layer_collection,
+ momentum=0.5,
+ momentum_type='regular')
+
+ def testSquaredFisherNorm(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
+ (array_ops.constant([[2., 3.], [4., 5.]]), None)]
+ pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
+ (array_ops.constant([[7., 8.], [9., 10.]]), None)]
+ opt = optimizer.KfacOptimizer(0.1, 0.2, 0.3, dummy_layer_collection())
+ sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
+ self.assertAlmostEqual(174., sess.run(sq_norm), places=5)
+
+ def testUpdateClipCoeff(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
+ (array_ops.constant([[2., 3.], [4., 5.]]), None)]
+ pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
+ (array_ops.constant([[7., 8.], [9., 10.]]), None)]
+ lrate = 0.1
+
+ # Note: without rescaling, the squared Fisher norm of the update
+ # is 1.74
+
+ # If the update already satisfies the norm constraint, there should
+ # be no rescaling.
+ opt = optimizer.KfacOptimizer(
+ lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=10.)
+ coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
+ self.assertAlmostEqual(1., sess.run(coeff), places=5)
+
+ # If the update violates the constraint, it should be rescaled to
+ # be on the constraint boundary.
+ opt = optimizer.KfacOptimizer(
+ lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=0.5)
+ coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
+ sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
+ sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad
+ self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5)
+
+ def testComputeUpdateStepsRegular(self):
+ # TODO(olganw): implement this.
+ pass
+
+ def testComputeUpdateStepsAdam(self):
+ # TODO(olganw): implement this.
+ pass
+
+ def testUpdateVelocities(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ layers = lc.LayerCollection()
+ layers.register_categorical_predictive_distribution(
+ array_ops.constant([1.0]))
+ opt = optimizer.KfacOptimizer(
+ 0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular')
+ x = variable_scope.get_variable('x', initializer=array_ops.ones((2, 2)))
+ y = variable_scope.get_variable(
+ 'y', initializer=array_ops.ones((2, 2)) * 2)
+ vec1 = array_ops.ones((2, 2)) * 3
+ vec2 = array_ops.ones((2, 2)) * 4
+
+ model_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5)
+ opt_vars = [
+ v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ if v not in model_vars
+ ]
+
+ sess.run(tf_variables.global_variables_initializer())
+ old_opt_vars = sess.run(opt_vars)
+
+ # Optimizer vars start out at 0.
+ for opt_var in old_opt_vars:
+ self.assertAllEqual(sess.run(array_ops.zeros_like(opt_var)), opt_var)
+
+ sess.run(update_op)
+ new_opt_vars = sess.run(opt_vars)
+ # After one update, the velocities are equal to the vectors.
+ for vec, opt_var in zip([vec1, vec2], new_opt_vars):
+ self.assertAllEqual(sess.run(vec), opt_var)
+
+ sess.run(update_op)
+ final_opt_vars = sess.run(opt_vars)
+ for first, second in zip(new_opt_vars, final_opt_vars):
+ self.assertFalse(np.equal(first, second).all())
+
+ def testApplyGradients(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ layer_collection = lc.LayerCollection()
+
+ inputs = array_ops.ones((2, 1)) * 2
+ weights_val = np.ones((1, 1), dtype=np.float32) * 3.
+ weights = variable_scope.get_variable(
+ 'w', initializer=array_ops.constant(weights_val))
+ bias = variable_scope.get_variable(
+ 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
+ output = math_ops.matmul(inputs, weights) + bias
+
+ layer_collection.register_fully_connected((weights, bias), inputs, output)
+
+ logits = math_ops.tanh(output)
+ targets = array_ops.constant([[0.], [1.]])
+ output = math_ops.reduce_mean(
+ nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))
+
+ layer_collection.register_categorical_predictive_distribution(logits)
+
+ opt = optimizer.KfacOptimizer(
+ 0.1,
+ 0.2,
+ 0.3,
+ layer_collection,
+ momentum=0.5,
+ momentum_type='regular')
+ (cov_update_thunks,
+ inv_update_thunks) = opt.make_vars_and_create_op_thunks()
+ cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
+ inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
+
+ grads_and_vars = opt.compute_gradients(output, [weights, bias])
+ all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars]
+
+ op = opt.apply_gradients(grads_and_vars)
+
+ sess.run(tf_variables.global_variables_initializer())
+ old_vars = sess.run(all_vars)
+ sess.run(cov_update_ops)
+ sess.run(inv_update_ops)
+ sess.run(op)
+ new_vars = sess.run(all_vars)
+
+ for old_var, new_var in zip(old_vars, new_vars):
+ self.assertNotEqual(old_var, new_var)
+
+
+if __name__ == '__main__':
+ test.main()
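The following is a hedged sketch of the end-to-end optimizer usage that testApplyGradients exercises; `loss`, `var_list`, `layer_collection`, and `sess` are assumed to be built as in that test, and the constructor constants simply mirror the test values.

    opt = optimizer.KfacOptimizer(
        0.1, 0.2, 0.3, layer_collection, momentum=0.5, momentum_type='regular')
    cov_update_thunks, inv_update_thunks = opt.make_vars_and_create_op_thunks()
    cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
    inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
    train_op = opt.apply_gradients(opt.compute_gradients(loss, var_list))

    sess.run(tf_variables.global_variables_initializer())
    sess.run(cov_update_ops)   # refresh covariance estimates
    sess.run(inv_update_ops)   # refresh their inverses
    sess.run(train_op)         # apply the preconditioned update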
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
new file mode 100644
index 0000000000..2cee01212a
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
@@ -0,0 +1,410 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.contrib.kfac.utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import numpy.random as npr
+
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.contrib.tpu.python.tpu import tpu_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class SequenceDictTest(test.TestCase):
+
+ def testSequenceDictInit(self):
+ seq_dict = utils.SequenceDict()
+ self.assertFalse(seq_dict._dict)
+
+ def testSequenceDictInitWithIterable(self):
+ reg_dict = {'a': 'foo', 'b': 'bar'}
+ itr = zip(reg_dict.keys(), reg_dict.values())
+ seq_dict = utils.SequenceDict(itr)
+ self.assertEqual(reg_dict, seq_dict._dict)
+
+ def testGetItemSingleKey(self):
+ seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
+ self.assertEqual('foo', seq_dict['a'])
+
+ def testGetItemMultipleKeys(self):
+ seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
+ self.assertEqual(['foo', 'bar'], seq_dict[('a', 'b')])
+
+ def testSetItemSingleKey(self):
+ seq_dict = utils.SequenceDict()
+ seq_dict['a'] = 'foo'
+ self.assertEqual([('a', 'foo')], seq_dict.items())
+
+ def testSetItemMultipleKeys(self):
+ seq_dict = utils.SequenceDict()
+ keys = ('a', 'b', 'c')
+ values = ('foo', 'bar', 'baz')
+ seq_dict[keys] = values
+ self.assertItemsEqual(list(zip(keys, values)), seq_dict.items())
+
+
+class SubGraphTest(test.TestCase):
+
+ def testBasicGraph(self):
+ a = array_ops.constant([[1., 2.], [3., 4.]])
+ b = array_ops.constant([[5., 6.], [7., 8.]])
+ c = a + b
+ d = a * b
+ sub_graph = utils.SubGraph((c,))
+ self.assertTrue(sub_graph.is_member(a))
+ self.assertTrue(sub_graph.is_member(b))
+ self.assertTrue(sub_graph.is_member(c))
+ self.assertFalse(sub_graph.is_member(d))
+
+ def testRepeatedAdds(self):
+ a = array_ops.constant([[1., 2.], [3., 4.]])
+ b = array_ops.constant([[5., 6.], [7., 8.]])
+ c = a + b + a # note that a appears twice in this graph
+ sub_graph = utils.SubGraph((c,))
+ self.assertTrue(sub_graph.is_member(a))
+ self.assertTrue(sub_graph.is_member(b))
+ self.assertTrue(sub_graph.is_member(c))
+
+ def testFilterList(self):
+ a = array_ops.constant([[1., 2.], [3., 4.]])
+ b = array_ops.constant([[5., 6.], [7., 8.]])
+ c = a + b
+ d = a * b
+ sub_graph = utils.SubGraph((c,))
+ input_list = [b, d]
+ filtered_list = sub_graph.filter_list(input_list)
+ self.assertEqual(filtered_list, [b])
+
+ def testVariableUses(self):
+ with ops.Graph().as_default():
+ var = variable_scope.get_variable('var', shape=[10, 10])
+ resource_var = variable_scope.get_variable(
+ 'resource_var', shape=[10, 10], use_resource=True)
+ x = array_ops.zeros([3, 10])
+ z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var)
+ z1 = math_ops.matmul(x, resource_var)
+ sub_graph = utils.SubGraph((z0, z1))
+ self.assertEqual(2, sub_graph.variable_uses(var))
+ self.assertEqual(1, sub_graph.variable_uses(resource_var))
+
+
+class UtilsTest(test.TestCase):
+
+ def _fully_connected_layer_params(self):
+ weights_part = array_ops.constant([[1., 2.], [4., 3.]])
+ bias_part = array_ops.constant([1., 2.])
+ return (weights_part, bias_part)
+
+ def _conv_layer_params(self):
+ weights_shape = 2, 2, 3, 4
+ biases_shape = weights_shape[-1:]
+ weights = array_ops.constant(npr.RandomState(0).randn(*weights_shape))
+ biases = array_ops.constant(npr.RandomState(1).randn(*biases_shape))
+ return (weights, biases)
+
+ def testFullyConnectedLayerParamsTupleToMat2d(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ layer_params = self._fully_connected_layer_params()
+ output = utils.layer_params_to_mat2d(layer_params)
+ self.assertListEqual([3, 2], output.get_shape().as_list())
+ self.assertAllClose(
+ sess.run(output), np.array([[1., 2.], [4., 3.], [1., 2.]]))
+
+ def testFullyConnectedLayerParamsTensorToMat2d(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ layer_params = self._fully_connected_layer_params()
+ output = utils.layer_params_to_mat2d(layer_params[0])
+ self.assertListEqual([2, 2], output.get_shape().as_list())
+ self.assertAllClose(sess.run(output), np.array([[1., 2.], [4., 3.]]))
+
+ def testConvLayerParamsTupleToMat2d(self):
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(200)
+ layer_params = self._conv_layer_params()
+ output = utils.layer_params_to_mat2d(layer_params)
+ self.assertListEqual([2 * 2 * 3 + 1, 4], output.get_shape().as_list())
+
+ def testKron(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ mat1 = np.array([[1., 2.], [3., 4.]])
+ mat2 = np.array([[5., 6.], [7., 8.]])
+ mat1_tf = array_ops.constant(mat1)
+ mat2_tf = array_ops.constant(mat2)
+ ans_tf = sess.run(utils.kronecker_product(mat1_tf, mat2_tf))
+ ans_np = np.kron(mat1, mat2)
+ self.assertAllClose(ans_tf, ans_np)
+
+ def testMat2dToFullyConnectedLayerParamsTuple(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ vector_template = self._fully_connected_layer_params()
+ mat2d = array_ops.constant([[5., 4.], [3., 2.], [1., 0.]])
+
+ output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))
+
+ self.assertIsInstance(output, tuple)
+ self.assertEqual(len(output), 2)
+ a, b = output
+ self.assertAllClose(a, np.array([[5., 4.], [3., 2.]]))
+ self.assertAllClose(b, np.array([1., 0.]))
+
+ def testMat2dToFullyConnectedLayerParamsTensor(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ vector_template = self._fully_connected_layer_params()[0]
+ mat2d = array_ops.constant([[5., 4.], [3., 2.]])
+
+ output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))
+
+ self.assertAllClose(output, np.array([[5., 4.], [3., 2.]]))
+
+ def testTensorsToColumn(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+
+ vector = array_ops.constant(np.array([[0., 1.], [2., 3.]]))
+ output = utils.tensors_to_column(vector)
+ self.assertListEqual([4, 1], output.get_shape().as_list())
+ self.assertAllClose(sess.run(output), np.array([0., 1., 2., 3.])[:, None])
+
+ vector = self._fully_connected_layer_params()
+ output = utils.tensors_to_column(vector)
+ self.assertListEqual([6, 1], output.get_shape().as_list())
+ self.assertAllClose(
+ sess.run(output), np.array([1., 2., 4., 3., 1., 2.])[:, None])
+
+ vector = list(vector)
+ vector.append(array_ops.constant([[6.], [7.], [8.], [9.]]))
+
+ output = utils.tensors_to_column(vector)
+ self.assertListEqual([10, 1], output.get_shape().as_list())
+ self.assertAllClose(
+ sess.run(output),
+ np.array([1., 2., 4., 3., 1., 2., 6., 7., 8., 9.])[:, None])
+
+ def testColumnToTensors(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+
+ vector_template = array_ops.constant(np.array([[0., 1.], [2., 3.]]))
+ colvec = array_ops.constant(np.arange(4.)[:, None])
+ output = sess.run(utils.column_to_tensors(vector_template, colvec))
+ self.assertAllClose(output, np.array([[0., 1.], [2., 3.]]))
+
+ vector_template = self._fully_connected_layer_params()
+ colvec = array_ops.constant(np.arange(6.)[:, None])
+ output = sess.run(utils.column_to_tensors(vector_template, colvec))
+
+ self.assertIsInstance(output, tuple)
+ self.assertEqual(len(output), 2)
+ a, b = output
+ self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))
+ self.assertAllClose(b, np.array([4., 5.]))
+
+ vector_template = list(vector_template)
+ vector_template.append(array_ops.constant([[6.], [7.], [8.], [9.]]))
+ colvec = array_ops.constant(np.arange(10.)[:, None])
+ output = sess.run(utils.column_to_tensors(vector_template, colvec))
+ self.assertIsInstance(output, tuple)
+ self.assertEqual(len(output), 3)
+ a, b, c = output
+ self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))
+ self.assertAllClose(b, np.array([4., 5.]))
+ self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]]))
+
+ def testPosDefInvCholesky(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ npr.seed(0)
+ square = lambda x: np.dot(x, x.T)
+
+ size = 3
+ x = square(npr.randn(size, size))
+ damp = 0.1
+ identity = linalg_ops.eye(size, dtype=dtypes.float64)
+
+ tf_inv = utils.posdef_inv_cholesky(array_ops.constant(x), identity, damp)
+ np_inv = np.linalg.inv(x + damp * np.eye(size))
+ self.assertAllClose(sess.run(tf_inv), np_inv)
+
+ def testPosDefInvMatrixInverse(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ random_seed.set_random_seed(200)
+ npr.seed(0)
+ square = lambda x: np.dot(x, x.T)
+
+ size = 3
+ x = square(npr.randn(size, size))
+ damp = 0.1
+ identity = linalg_ops.eye(size, dtype=dtypes.float64)
+
+ tf_inv = utils.posdef_inv_matrix_inverse(
+ array_ops.constant(x), identity, damp)
+ np_inv = np.linalg.inv(x + damp * np.eye(size))
+ self.assertAllClose(sess.run(tf_inv), np_inv)
+
+ def testCrossReplicaMean(self):
+ """Ensures that cross_replica_mean() executes only when num_shards > 1."""
+ with ops.Graph().as_default():
+ with tpu_function.tpu_shard_context(4):
+ tensor = array_ops.zeros([], dtype=dtypes.float32)
+ mean = utils.cross_replica_mean(tensor)
+ self.assertNotEqual(mean, tensor)
+
+ with ops.Graph().as_default():
+ with tpu_function.tpu_shard_context(1):
+ tensor = array_ops.zeros([], dtype=dtypes.float32)
+ mean = utils.cross_replica_mean(tensor)
+ self.assertEqual(mean, tensor)
+
+ with ops.Graph().as_default():
+ with self.assertRaises(ValueError): # Outside of TPU context.
+ tensor = array_ops.zeros([], dtype=dtypes.float32)
+ mean = utils.cross_replica_mean(tensor)
+
+ def testBatchExecute(self):
+ """Ensure batch_execute runs in a round-robin fashion."""
+
+ def increment_var(var):
+ return lambda: var.assign_add(1)
+
+ with ops.Graph().as_default(), self.test_session() as sess:
+ i = variable_scope.get_variable('i', initializer=0)
+ accumulators = [
+ variable_scope.get_variable('var%d' % j, initializer=0)
+ for j in range(3)
+ ]
+ thunks = [increment_var(var) for var in accumulators]
+ increment_accumulators = utils.batch_execute(i, thunks, 2)
+ increment_i = i.assign_add(1)
+
+ sess.run(variables.global_variables_initializer())
+
+ # Ensure one op per thunk.
+ self.assertEqual(3, len(increment_accumulators))
+
+ # Ensure round-robin execution.
+ values = []
+ for _ in range(5):
+ sess.run(increment_accumulators)
+ sess.run(increment_i)
+ values.append(sess.run(accumulators))
+ self.assertAllClose(
+ [
+ [1, 1, 0], #
+ [2, 1, 1], #
+ [2, 2, 2], #
+ [3, 3, 2], #
+ [4, 3, 3]
+ ],
+ values)
+
+ def testExtractConvolutionPatches(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ batch_size = 10
+ image_spatial_shape = [9, 10, 11]
+ in_channels = out_channels = 32
+ kernel_spatial_shape = [5, 3, 3]
+ spatial_strides = [1, 2, 1]
+ spatial_dilation = [1, 1, 1]
+ padding = 'SAME'
+
+ images = random_ops.random_uniform(
+ [batch_size] + image_spatial_shape + [in_channels], seed=0)
+ kernel_shape = kernel_spatial_shape + [in_channels, out_channels]
+ kernel = random_ops.random_uniform(kernel_shape, seed=1)
+
+ # Ensure shape matches expectation.
+ patches = utils.extract_convolution_patches(
+ images,
+ kernel_shape,
+ padding,
+ strides=spatial_strides,
+ dilation_rate=spatial_dilation)
+ result_spatial_shape = (
+ patches.shape.as_list()[1:1 + len(image_spatial_shape)])
+ self.assertEqual(patches.shape.as_list(),
+ [batch_size] + result_spatial_shape +
+ kernel_spatial_shape + [in_channels])
+
+ # Ensure extract...patches() + matmul() and convolution() implementation
+ # give the same answer.
+ outputs = nn_ops.convolution(
+ images,
+ kernel,
+ padding,
+ strides=spatial_strides,
+ dilation_rate=spatial_dilation)
+
+ patches_flat = array_ops.reshape(
+ patches, [-1, np.prod(kernel_spatial_shape) * in_channels])
+ kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
+ outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
+
+ outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
+ self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
+
+ def testExtractPointwiseConv2dPatches(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ batch_size = 10
+ image_height = image_width = 8
+ in_channels = out_channels = 3
+ kernel_height = kernel_width = 1
+ strides = [1, 1, 1, 1]
+ padding = 'VALID'
+
+ images = random_ops.random_uniform(
+ [batch_size, image_height, image_width, in_channels], seed=0)
+ kernel_shape = [kernel_height, kernel_width, in_channels, out_channels]
+ kernel = random_ops.random_uniform(kernel_shape, seed=1)
+
+ # Ensure shape matches expectation.
+ patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape)
+ self.assertEqual(patches.shape.as_list(), [
+ batch_size, image_height, image_width, kernel_height, kernel_width,
+ in_channels
+ ])
+
+ # Ensure extract...patches() + matmul() and conv2d() implementation
+ # give the same answer.
+ outputs = nn_ops.conv2d(images, kernel, strides, padding)
+
+ patches_flat = array_ops.reshape(
+ patches, [-1, kernel_height * kernel_width * in_channels])
+ kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
+ outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
+
+ outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
+ self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
+
+
+if __name__ == '__main__':
+ test.main()
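As a quick cross-check of the Kronecker utility exercised by testKron, the sketch below restates the equivalence being tested (imports and `sess` as in the test above; names illustrative).

    mat1 = np.array([[1., 2.], [3., 4.]])
    mat2 = np.array([[5., 6.], [7., 8.]])
    tf_ans = sess.run(utils.kronecker_product(
        array_ops.constant(mat1), array_ops.constant(mat2)))
    np.testing.assert_allclose(tf_ans, np.kron(mat1, mat2))   # should agree exactly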
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
new file mode 100644
index 0000000000..3c01eb65e7
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/BUILD
@@ -0,0 +1,263 @@
+package(default_visibility = [
+ "//tensorflow/contrib/kfac:__pkg__",
+ "//tensorflow/contrib/kfac/python/kernel_tests:__pkg__",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "fisher_blocks",
+ srcs = ["fisher_blocks.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":fisher_factors",
+ ":utils",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:math_ops",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
+ name = "fisher_blocks_lib",
+ srcs = ["fisher_blocks_lib.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":fisher_blocks",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
+ name = "fisher_factors",
+ srcs = ["fisher_factors.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":linear_operator",
+ ":utils",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:linalg_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:special_math_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
+ name = "fisher_factors_lib",
+ srcs = ["fisher_factors_lib.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":fisher_factors",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
+ name = "linear_operator",
+ srcs = ["linear_operator.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":utils",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/ops/linalg",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
+ name = "loss_functions",
+ srcs = ["loss_functions.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/distributions:distributions_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/ops/distributions",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
+ name = "loss_functions_lib",
+ srcs = ["loss_functions_lib.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":loss_functions",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
+ name = "curvature_matrix_vector_products",
+ srcs = ["curvature_matrix_vector_products.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":utils",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
+ name = "curvature_matrix_vector_products_lib",
+ srcs = ["curvature_matrix_vector_products_lib.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":curvature_matrix_vector_products",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
+ name = "layer_collection",
+ srcs = ["layer_collection.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":fisher_blocks",
+ ":loss_functions",
+ ":utils",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
+ name = "layer_collection_lib",
+ srcs = ["layer_collection_lib.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":layer_collection",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
+ name = "kfac_optimizer",
+ srcs = [
+ "optimizer.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":curvature_matrix_vector_products",
+ ":fisher_estimator",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:linalg_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+)
+
+py_library(
+ name = "kfac_optimizer_lib",
+ srcs = [
+ "optimizer_lib.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":kfac_optimizer",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
+ name = "fisher_estimator",
+ srcs = [
+ "estimator.py",
+ "placement.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":utils",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:util",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
+ name = "fisher_estimator_lib",
+ srcs = [
+ "estimator_lib.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":fisher_estimator",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
+ name = "utils",
+ srcs = ["utils.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/tpu",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:linalg_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "utils_lib",
+ srcs = ["utils_lib.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":utils",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
+ name = "op_queue",
+ srcs = ["op_queue.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:dataset_ops",
+ "//tensorflow/python:framework_ops",
+ ],
+)
+
+py_library(
+ name = "op_queue_lib",
+ srcs = ["op_queue_lib.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":op_queue",
+ "//tensorflow/python:util",
+ ],
+)
diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py
new file mode 100644
index 0000000000..21b5cde9b9
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py
@@ -0,0 +1,183 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Curvature matrix-vector multiplication."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util import nest
+
+
+class CurvatureMatrixVectorProductComputer(object):
+ """Class for computing matrix-vector products for Fishers, GGNs and Hessians.
+
+ In other words we compute M*v where M is the matrix, v is the vector, and
+ * refers to standard matrix/vector multiplication (not element-wise
+ multiplication).
+
+ The matrices are defined in terms of some differential quantity of the total
+ loss function with respect to a provided list of tensors ("wrt_tensors").
+ For example, the Fisher associated with a log-prob loss w.r.t. the
+ parameters.
+
+  The 'vecs' argument to each method is a list of tensors that must be the
+  same size as the corresponding ones from "wrt_tensors". They represent
+  the vector being multiplied.
+
+ "factors" of the matrix M are defined as matrices B such that B*B^T = M.
+ Methods that multiply by the factor B take a 'loss_inner_vecs' argument
+ instead of 'vecs', which must be a list of tensors with shapes given by the
+ corresponding XXX_inner_shapes property.
+
+ Note that matrix-vector products are not normalized by the batch size, nor
+ are any damping terms added to the results. These things can be easily
+ applied externally, if desired.
+
+ See for example: www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf
+ and https://arxiv.org/abs/1412.1193 for more information about the
+ generalized Gauss-Newton, Fisher, etc., and how to compute matrix-vector
+ products.
+ """
+
+ def __init__(self, losses, wrt_tensors):
+ """Create a CurvatureMatrixVectorProductComputer object.
+
+ Args:
+ losses: A list of LossFunction instances whose sum defines the total loss.
+ wrt_tensors: A list of Tensors to compute the differential quantities
+ (defining the matrices) with respect to. See class description for more
+ info.
+ """
+ self._losses = losses
+ self._inputs_to_losses = list(loss.inputs for loss in losses)
+ self._inputs_to_losses_flat = nest.flatten(self._inputs_to_losses)
+ self._wrt_tensors = wrt_tensors
+
+ @property
+ def _total_loss(self):
+ return math_ops.add_n(tuple(loss.evaluate() for loss in self._losses))
+
+ # Jacobian multiplication functions:
+ def _multiply_jacobian(self, vecs):
+ """Multiply vecs by the Jacobian of losses."""
+ # We stop gradients at wrt_tensors to produce partial derivatives (which is
+ # what we want for Jacobians).
+ jacobian_vecs_flat = utils.fwd_gradients(
+ self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs,
+ stop_gradients=self._wrt_tensors)
+ return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat)
+
+ def _multiply_jacobian_transpose(self, loss_vecs):
+ """Multiply vecs by the transpose Jacobian of losses."""
+ loss_vecs_flat = nest.flatten(loss_vecs)
+ # We stop gradients at wrt_tensors to produce partial derivatives (which is
+ # what we want for Jacobians).
+ return gradients_impl.gradients(
+ self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat,
+ stop_gradients=self._wrt_tensors)
+
+ # Losses Fisher/Hessian multiplication functions:
+ def _multiply_loss_fisher(self, loss_vecs):
+ """Multiply loss_vecs by Fisher of total loss."""
+ return tuple(
+ loss.multiply_fisher(loss_vec)
+ for loss, loss_vec in zip(self._losses, loss_vecs))
+
+ def _multiply_loss_fisher_factor(self, loss_inner_vecs):
+ """Multiply loss_inner_vecs by factor of Fisher of total loss."""
+ return tuple(
+ loss.multiply_fisher_factor(loss_vec)
+ for loss, loss_vec in zip(self._losses, loss_inner_vecs))
+
+ def _multiply_loss_fisher_factor_transpose(self, loss_vecs):
+ """Multiply loss_vecs by transpose factor of Fisher of total loss."""
+ return tuple(
+ loss.multiply_fisher_factor_transpose(loss_vec)
+ for loss, loss_vec in zip(self._losses, loss_vecs))
+
+ def _multiply_loss_hessian(self, loss_vecs):
+ """Multiply loss_vecs by Hessian of total loss."""
+ return tuple(
+ loss.multiply_hessian(loss_vec)
+ for loss, loss_vec in zip(self._losses, loss_vecs))
+
+ def _multiply_loss_hessian_factor(self, loss_inner_vecs):
+ """Multiply loss_inner_vecs by factor of Hessian of total loss."""
+ return tuple(
+ loss.multiply_hessian_factor(loss_vec)
+ for loss, loss_vec in zip(self._losses, loss_inner_vecs))
+
+ def _multiply_loss_hessian_factor_transpose(self, loss_vecs):
+ """Multiply loss_vecs by transpose factor of Hessian of total loss."""
+ return tuple(
+ loss.multiply_hessian_factor_transpose(loss_vec)
+ for loss, loss_vec in zip(self._losses, loss_vecs))
+
+ # Matrix-vector product functions:
+ def multiply_fisher(self, vecs):
+ """Multiply vecs by Fisher of total loss."""
+ jacobian_vecs = self._multiply_jacobian(vecs)
+ loss_fisher_jacobian_vecs = self._multiply_loss_fisher(jacobian_vecs)
+ return self._multiply_jacobian_transpose(loss_fisher_jacobian_vecs)
+
+ def multiply_fisher_factor_transpose(self, vecs):
+ """Multiply vecs by transpose of factor of Fisher of total loss."""
+ jacobian_vecs = self._multiply_jacobian(vecs)
+ return self._multiply_loss_fisher_factor_transpose(jacobian_vecs)
+
+ def multiply_fisher_factor(self, loss_inner_vecs):
+ """Multiply loss_inner_vecs by factor of Fisher of total loss."""
+ fisher_factor_transpose_vecs = self._multiply_loss_fisher_factor_transpose(
+ loss_inner_vecs)
+ return self._multiply_jacobian_transpose(fisher_factor_transpose_vecs)
+
+ def multiply_hessian(self, vecs):
+ """Multiply vecs by Hessian of total loss."""
+ return gradients_impl.gradients(
+ gradients_impl.gradients(self._total_loss, self._wrt_tensors),
+ self._wrt_tensors,
+ grad_ys=vecs)
+
+ def multiply_generalized_gauss_newton(self, vecs):
+ """Multiply vecs by generalized Gauss-Newton of total loss."""
+ jacobian_vecs = self._multiply_jacobian(vecs)
+ loss_hessian_jacobian_vecs = self._multiply_loss_hessian(jacobian_vecs)
+ return self._multiply_jacobian_transpose(loss_hessian_jacobian_vecs)
+
+ def multiply_generalized_gauss_newton_factor_transpose(self, vecs):
+ """Multiply vecs by transpose of factor of GGN of total loss."""
+ jacobian_vecs = self._multiply_jacobian(vecs)
+ return self._multiply_loss_hessian_factor_transpose(jacobian_vecs)
+
+ def multiply_generalized_gauss_newton_factor(self, loss_inner_vecs):
+ """Multiply loss_inner_vecs by factor of GGN of total loss."""
+ hessian_factor_transpose_vecs = (
+ self._multiply_loss_hessian_factor_transpose(loss_inner_vecs))
+ return self._multiply_jacobian_transpose(hessian_factor_transpose_vecs)
+
+ # Shape properties for multiply_XXX_factor methods:
+ @property
+ def fisher_factor_inner_shapes(self):
+ """Shapes required by multiply_fisher_factor."""
+ return tuple(loss.fisher_factor_inner_shape for loss in self._losses)
+
+ @property
+ def generalized_gauss_newton_factor_inner_shapes(self):
+ """Shapes required by multiply_generalized_gauss_newton_factor."""
+ return tuple(loss.hessian_factor_inner_shape for loss in self._losses)
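A minimal usage sketch (not part of this commit) of the computer defined above. `layer_collection`, `wrt_tensors`, and `vec` are hypothetical placeholders for objects built elsewhere, and the positional constructor arguments are assumed from the attribute assignments in `__init__`:

    # Hedged sketch; the names below are illustrative, not from the diff.
    from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp

    computer = cmvp.CurvatureMatrixVectorProductComputer(
        layer_collection.losses,  # registered LossFunction objects
        wrt_tensors)              # parameters the curvature is taken w.r.t.

    fvp = computer.multiply_fisher(vec)                      # J^T F_loss J vec
    ggnvp = computer.multiply_generalized_gauss_newton(vec)  # J^T H_loss J vec
    hvp = computer.multiply_hessian(vec)                     # Hessian-vector product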
diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py
new file mode 100644
index 0000000000..6e8c6404dc
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py
@@ -0,0 +1,30 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Curvature matrix-vector multiplication."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.kfac.python.ops.curvature_matrix_vector_products import *
+from tensorflow.python.util.all_util import remove_undocumented
+# pylint: enable=unused-import,line-too-long,wildcard-import
+
+_allowed_symbols = [
+ 'CurvatureMatrixVectorProductComputer',
+]
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
new file mode 100644
index 0000000000..323234c403
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/estimator.py
@@ -0,0 +1,516 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines the high-level Fisher estimator class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import numpy as np
+import six
+
+from tensorflow.contrib.kfac.python.ops import placement
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import nest
+
+
+# The linter is confused.
+# pylint: disable=abstract-class-instantiated
+def make_fisher_estimator(placement_strategy=None, **kwargs):
+ """Creates Fisher estimator instances based on the placement strategy.
+
+ For example, if the `placement_strategy` is 'round_robin' then a
+ `FisherEstimatorRoundRobin` instance is returned.
+
+ Args:
+ placement_strategy: `string`, Strategy to be used for placing covariance
+ variables, covariance ops and inverse ops. Check
+ `placement.FisherEstimatorRoundRobin` for a concrete example.
+ **kwargs: Arguments to be passed into `FisherEstimator` class initializer.
+
+ Returns:
+ An instance of a class that inherits from `FisherEstimator` and the mixin
+ which implements the specific placement strategy. See
+ `FisherEstimatorRoundRobin`, which inherits from `FisherEstimator` and
+ `RoundRobinPlacementMixin`.
+
+ Raises:
+ ValueError: If the `placement_strategy` is not equal to 'round_robin'.
+ """
+ if placement_strategy in [None, "round_robin"]:
+ return FisherEstimatorRoundRobin(**kwargs)
+ else:
+ raise ValueError("Unimplemented vars and ops "
+ "placement strategy : {}".format(placement_strategy))
+# pylint: enable=abstract-class-instantiated
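A hedged sketch of calling this factory (not part of the diff). `layer_collection` stands in for a `LayerCollection` built elsewhere, and the hyperparameter values are purely illustrative:

    import tensorflow as tf

    estimator = make_fisher_estimator(
        placement_strategy="round_robin",   # None selects the same class
        variables=tf.trainable_variables(),
        cov_ema_decay=0.95,
        damping=1e-3,
        layer_collection=layer_collection)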
+
+
+@six.add_metaclass(abc.ABCMeta)
+class FisherEstimator(object):
+ """Fisher estimator class supporting various approximations of the Fisher.
+
+ This is an abstract base class which does not implement a strategy for
+ placing covariance variables, covariance update ops and inverse update ops.
+ The placement strategies are implemented in `placement.py`. See
+ `FisherEstimatorRoundRobin` for an example of a concrete subclass with
+ a round-robin placement strategy.
+ """
+
+ def __init__(self,
+ variables,
+ cov_ema_decay,
+ damping,
+ layer_collection,
+ exps=(-1,),
+ estimation_mode="gradients",
+ colocate_gradients_with_ops=True,
+ name="FisherEstimator",
+ compute_cholesky=False,
+ compute_cholesky_inverse=False):
+ """Create a FisherEstimator object.
+
+ Args:
+ variables: A `list` of variables or `callable` which returns the variables
+ for which to estimate the Fisher. This must match the variables
+ registered in layer_collection (if it is not None).
+ cov_ema_decay: The decay factor used when calculating the covariance
+ estimate moving averages.
+ damping: float. The damping factor used to stabilize training due to
+ errors in the local approximation with the Fisher information matrix,
+ and to regularize the update direction by making it closer to the
+ gradient. (Higher damping means the update looks more like a standard
+ gradient update - see Tikhonov regularization.)
+ layer_collection: The layer collection object, which holds the Fisher
+ blocks, Kronecker factors, and losses associated with the
+ graph.
+ exps: List of floats or ints. These represent the different matrix
+ powers of the approximate Fisher that the FisherEstimator will be able
+ to multiply vectors by. If the user asks for a matrix power other
+ than one of these (or 1, which is always supported), there will be a
+ failure. (Default: (-1,))
+ estimation_mode: The type of estimator to use for the Fishers. Can be
+ 'gradients', 'empirical', 'curvature_prop', or 'exact'.
+ (Default: 'gradients'). 'gradients' is the basic estimation approach
+ from the original K-FAC paper. 'empirical' computes the 'empirical'
+ Fisher information matrix (which uses the data's distribution for the
+ targets, as opposed to the true Fisher which uses the model's
+ distribution) and requires that each registered loss have specified
+ targets. 'curvature_prop' is a method which estimates the
+ Fisher using self-products of random 1/-1 vectors times "half-factors"
+ of the Fisher, as described here: https://arxiv.org/abs/1206.6464 .
+ Finally, 'exact' is the obvious generalization of Curvature
+ Propagation to compute the exact Fisher (modulo any additional
+ diagonal or Kronecker approximations) by looping over one-hot vectors
+ for each coordinate of the output instead of using 1/-1 vectors. It
+ is more expensive to compute than the other three options by a factor
+ equal to the output dimension, roughly speaking.
+ colocate_gradients_with_ops: Whether we should request gradients be
+ colocated with their respective ops. (Default: True)
+ name: A string. A name given to this estimator, which is added to the
+ variable scope when constructing variables and ops.
+ (Default: "FisherEstimator")
+ compute_cholesky: Bool. Whether or not the FisherEstimator will be
+ able to multiply vectors by the Cholesky factor.
+ (Default: False)
+ compute_cholesky_inverse: Bool. Whether or not the FisherEstimator
+ will be able to multiply vectors by the Cholesky factor inverse.
+ (Default: False)
+ Raises:
+ ValueError: If no losses have been registered with layer_collection.
+ """
+ self._variables = variables
+ self._cov_ema_decay = cov_ema_decay
+ self._damping = damping
+ self._estimation_mode = estimation_mode
+ self._layers = layer_collection
+ self._gradient_fns = {
+ "gradients": self._get_grads_lists_gradients,
+ "empirical": self._get_grads_lists_empirical,
+ "curvature_prop": self._get_grads_lists_curvature_prop,
+ "exact": self._get_grads_lists_exact
+ }
+ self._colocate_gradients_with_ops = colocate_gradients_with_ops
+
+ self._made_vars = False
+ self._exps = exps
+ self._compute_cholesky = compute_cholesky
+ self._compute_cholesky_inverse = compute_cholesky_inverse
+
+ self._name = name
+
+ @property
+ def variables(self):
+ if callable(self._variables):
+ return self._variables()
+ else:
+ return self._variables
+
+ @property
+ def damping(self):
+ return self._damping
+
+ @property
+ def blocks(self):
+ """All registered FisherBlocks."""
+ return self._layers.get_blocks()
+
+ @property
+ def factors(self):
+ """All registered FisherFactors."""
+ return self._layers.get_factors()
+
+ @property
+ def name(self):
+ return self._name
+
+ @abc.abstractmethod
+ def make_vars_and_create_op_thunks(self, scope=None):
+ """Make vars and create op thunks with a specific placement strategy.
+
+ For each factor, all of that factor's cov variables and their associated
+ update ops will be placed on a particular device. A new device is chosen
+ for each factor by cycling through list of devices in the cov_devices
+ argument. If cov_devices is None then no explicit device placement occurs.
+
+ An analogous strategy is followed for inverse update ops, with the list of
+ devices being given by the inv_devices argument.
+
+ Inverse variables on the other hand are not placed on any specific device
+ (they will just use the current device placement context, whatever
+ that happens to be). The idea is that the inverse variables belong where
+ they will be accessed most often, which is the device that actually applies
+ the preconditioner to the gradient. The user is responsible for setting
+ the device context for this.
+
+ Args:
+ scope: A string or None. If None it will be set to the name of this
+ estimator (given by the name property). All variables will be created,
+ and all thunks will execute, inside of a variable scope of the given
+ name. (Default: None)
+
+ Returns:
+ cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
+ the list of factors given by the "factors" property.
+ inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
+ the list of factors given by the "factors" property.
+ """
+ pass
+
+ def _apply_transformation(self, vecs_and_vars, transform):
+ """Applies a block-wise transformation to the corresponding vectors.
+
+ Args:
+ vecs_and_vars: List of (vector, variable) pairs.
+ transform: A function of the form f(fb, vec), where vec is the vector
+ to transform and fb is its corresponding block in the matrix, that
+ returns the transformed vector.
+
+ Returns:
+ A list of (transformed vector, var) pairs in the same order as
+ vecs_and_vars.
+ """
+
+ vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars)
+
+ trans_vecs = utils.SequenceDict()
+
+ for params, fb in self._layers.fisher_blocks.items():
+ trans_vecs[params] = transform(fb, vecs[params])
+
+ return [(trans_vecs[var], var) for _, var in vecs_and_vars]
+
+ def multiply_inverse(self, vecs_and_vars):
+ """Multiplies the vecs by the corresponding (damped) inverses of the blocks.
+
+ Args:
+ vecs_and_vars: List of (vector, variable) pairs.
+
+ Returns:
+ A list of (transformed vector, var) pairs in the same order as
+ vecs_and_vars.
+ """
+ return self.multiply_matpower(-1, vecs_and_vars)
+
+ def multiply(self, vecs_and_vars):
+ """Multiplies the vectors by the corresponding (damped) blocks.
+
+ Args:
+ vecs_and_vars: List of (vector, variable) pairs.
+
+ Returns:
+ A list of (transformed vector, var) pairs in the same order as
+ vecs_and_vars.
+ """
+ return self.multiply_matpower(1, vecs_and_vars)
+
+ def multiply_matpower(self, exp, vecs_and_vars):
+ """Multiplies the vecs by the corresponding matrix powers of the blocks.
+
+ Args:
+ exp: A float representing the power to raise the blocks by before
+ multiplying it by the vector.
+ vecs_and_vars: List of (vector, variable) pairs.
+
+ Returns:
+ A list of (transformed vector, var) pairs in the same order as
+ vecs_and_vars.
+ """
+ assert exp in self._exps
+
+ fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
+ return self._apply_transformation(vecs_and_vars, fcn)
+
+ def multiply_cholesky(self, vecs_and_vars, transpose=False):
+ """Multiplies the vecs by the corresponding Cholesky factors.
+
+ Args:
+ vecs_and_vars: List of (vector, variable) pairs.
+ transpose: Bool. If true the Cholesky factors are transposed before
+ multiplying the vecs. (Default: False)
+
+ Returns:
+ A list of (transformed vector, var) pairs in the same order as
+ vecs_and_vars.
+ """
+ assert self._compute_cholesky
+
+ fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose)
+ return self._apply_transformation(vecs_and_vars, fcn)
+
+ def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False):
+ """Multiplies vecs by the inverses of the corresponding Cholesky factors.
+
+ Note: if you are using Cholesky inverse multiplication to sample from
+ a matrix-variate Gaussian you will want to multiply by the transpose.
+ Let L be the Cholesky factor of F and observe that
+
+ L^-T * L^-1 = (L * L^T)^-1 = F^-1 .
+
+ Thus we want to multiply by L^-T in order to sample from Gaussian with
+ covariance F^-1.
+
+ Args:
+ vecs_and_vars: List of (vector, variable) pairs.
+ transpose: Bool. If true the Cholesky factor inverses are transposed
+ before multiplying the vecs. (Default: False)
+
+ Returns:
+ A list of (transformed vector, var) pairs in the same order as
+ vecs_and_vars.
+ """
+ assert self._compute_cholesky_inverse
+
+ fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose)
+ return self._apply_transformation(vecs_and_vars, fcn)
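A sketch of the sampling use case described in the note above (illustrative, not part of the commit). It assumes the estimator was built with `compute_cholesky_inverse=True` and that `estimator` and `vecs_and_vars` already exist:

    import tensorflow as tf

    # One standard-normal tensor per variable, shaped like that variable.
    normal_vecs_and_vars = [(tf.random_normal(tf.shape(var)), var)
                            for _, var in vecs_and_vars]
    # transpose=True multiplies by L^-T, giving samples with covariance F^-1.
    sample = estimator.multiply_cholesky_inverse(normal_vecs_and_vars,
                                                 transpose=True)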
+
+ def _instantiate_factors(self):
+ """Instantiates FisherFactors' variables.
+
+ Raises:
+ ValueError: If estimation_mode was improperly specified at construction.
+ """
+ blocks = self.blocks
+ tensors_to_compute_grads = [
+ block.tensors_to_compute_grads() for block in blocks
+ ]
+
+ try:
+ grads_lists = self._gradient_fns[self._estimation_mode](
+ tensors_to_compute_grads)
+ except KeyError:
+ raise ValueError("Unrecognized value {} for estimation_mode.".format(
+ self._estimation_mode))
+
+ for grads_list, block in zip(grads_lists, blocks):
+ block.instantiate_factors(grads_list, self.damping)
+
+ def _check_vars_unmade_and_set_made_flag(self):
+ if self._made_vars:
+ raise Exception("Already made variables.")
+ self._made_vars = True
+
+ def made_vars(self):
+ return self._made_vars
+
+ def _register_matrix_functions(self):
+ for block in self.blocks:
+ for exp in self._exps:
+ block.register_matpower(exp)
+ if self._compute_cholesky:
+ block.register_cholesky()
+ if self._compute_cholesky_inverse:
+ block.register_cholesky_inverse()
+
+ def _finalize_layer_collection(self):
+ self._layers.create_subgraph()
+ self._layers.check_registration(self.variables)
+ self._instantiate_factors()
+ self._register_matrix_functions()
+
+ def create_ops_and_vars_thunks(self, scope=None):
+ """Create thunks that make the ops and vars on demand.
+
+ This function returns 4 lists of thunks: cov_variable_thunks,
+ cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
+
+ The length of each list is the number of factors and the i-th element of
+ each list corresponds to the i-th factor (given by the "factors" property).
+
+ Note that the execution of these thunks must happen in a certain
+ partial order. The i-th element of cov_variable_thunks must execute
+ before the i-th element of cov_update_thunks (and also the i-th element
+ of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
+ must execute before the i-th element of inv_update_thunks.
+
+ TL;DR (oversimplified): Execute the thunks according to the order that
+ they are returned.
+
+ Args:
+ scope: A string or None. If None it will be set to the name of this
+ estimator (given by the name property). All thunks will execute inside
+ of a variable scope of the given name. (Default: None)
+ Returns:
+ cov_variable_thunks: A list of thunks that make the cov variables.
+ cov_update_thunks: A list of thunks that make the cov update ops.
+ inv_variable_thunks: A list of thunks that make the inv variables.
+ inv_update_thunks: A list of thunks that make the inv update ops.
+ """
+ self._check_vars_unmade_and_set_made_flag()
+
+ self._finalize_layer_collection()
+
+ scope = self.name if scope is None else scope
+
+ cov_variable_thunks = [
+ self._create_cov_variable_thunk(factor, scope)
+ for factor in self.factors
+ ]
+ cov_update_thunks = [
+ self._create_cov_update_thunk(factor, scope) for factor in self.factors
+ ]
+ inv_variable_thunks = [
+ self._create_inv_variable_thunk(factor, scope)
+ for factor in self.factors
+ ]
+ inv_update_thunks = [
+ self._create_inv_update_thunk(factor, scope) for factor in self.factors
+ ]
+
+ return (cov_variable_thunks, cov_update_thunks,
+ inv_variable_thunks, inv_update_thunks)
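One way (a hedged sketch, not from the diff) to respect the partial order spelled out in the docstring is simply to run all variable thunks before any update thunks; `estimator` is a placeholder for a constructed subclass instance:

    (cov_variable_thunks, cov_update_thunks,
     inv_variable_thunks, inv_update_thunks) = (
         estimator.create_ops_and_vars_thunks())

    for thunk in cov_variable_thunks:   # cov variables first
      thunk()
    for thunk in inv_variable_thunks:   # then inv variables
      thunk()
    cov_update_ops = [thunk() for thunk in cov_update_thunks]
    inv_update_ops = [thunk() for thunk in inv_update_thunks]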
+
+ def _create_cov_variable_thunk(self, factor, scope):
+ """Constructs a covariance variable thunk for a single FisherFactor."""
+
+ def thunk():
+ with variable_scope.variable_scope(scope):
+ return factor.instantiate_cov_variables()
+
+ return thunk
+
+ def _create_cov_update_thunk(self, factor, scope):
+ """Constructs a covariance update thunk for a single FisherFactor."""
+
+ def thunk():
+ with variable_scope.variable_scope(scope):
+ return factor.make_covariance_update_op(self._cov_ema_decay)
+
+ return thunk
+
+ def _create_inv_variable_thunk(self, factor, scope):
+ """Constructs an inverse variable thunk for a single FisherFactor."""
+
+ def thunk():
+ with variable_scope.variable_scope(scope):
+ return factor.instantiate_inv_variables()
+
+ return thunk
+
+ def _create_inv_update_thunk(self, factor, scope):
+ """Constructs an inverse update thunk for a single FisherFactor."""
+
+ def thunk():
+ with variable_scope.variable_scope(scope):
+ return control_flow_ops.group(factor.make_inverse_update_ops())
+
+ return thunk
+
+ def _get_grads_lists_gradients(self, tensors):
+ # Passing in a list of loss values is better than passing in the sum as
+ # the latter creates unnecessary ops on the default device
+ grads_flat = gradients_impl.gradients(
+ self._layers.eval_losses_on_samples(),
+ nest.flatten(tensors),
+ colocate_gradients_with_ops=self._colocate_gradients_with_ops)
+ grads_all = nest.pack_sequence_as(tensors, grads_flat)
+ return tuple((grad,) for grad in grads_all)
+
+ def _get_grads_lists_empirical(self, tensors):
+ # Passing in a list of loss values is better than passing in the sum as
+ # the latter creates unnecessary ops on the default device
+ grads_flat = gradients_impl.gradients(
+ self._layers.eval_losses(),
+ nest.flatten(tensors),
+ colocate_gradients_with_ops=self._colocate_gradients_with_ops)
+ grads_all = nest.pack_sequence_as(tensors, grads_flat)
+ return tuple((grad,) for grad in grads_all)
+
+ def _get_transformed_random_signs(self):
+ transformed_random_signs = []
+ for loss in self._layers.losses:
+ with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
+ transformed_random_signs.append(
+ loss.multiply_fisher_factor(
+ utils.generate_random_signs(loss.fisher_factor_inner_shape)))
+ return transformed_random_signs
+
+ def _get_grads_lists_curvature_prop(self, tensors):
+ loss_inputs = list(loss.inputs for loss in self._layers.losses)
+ transformed_random_signs = self._get_transformed_random_signs()
+ grads_flat = gradients_impl.gradients(
+ nest.flatten(loss_inputs),
+ nest.flatten(tensors),
+ grad_ys=nest.flatten(transformed_random_signs),
+ colocate_gradients_with_ops=self._colocate_gradients_with_ops)
+ grads_all = nest.pack_sequence_as(tensors, grads_flat)
+ return tuple((grad,) for grad in grads_all)
+
+ def _get_grads_lists_exact(self, tensors):
+ """No docstring required."""
+ # Loop over all coordinates of all losses.
+ grads_all = []
+ for loss in self._layers.losses:
+ with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
+ for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
+ transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
+ index)
+ grads_flat = gradients_impl.gradients(
+ loss.inputs,
+ nest.flatten(tensors),
+ grad_ys=transformed_one_hot,
+ colocate_gradients_with_ops=self._colocate_gradients_with_ops)
+ grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
+ return zip(*grads_all)
+
+
+class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin,
+ FisherEstimator):
+ """Fisher estimator which provides round robin device placement strategy."""
+ pass
diff --git a/tensorflow/contrib/kfac/python/ops/estimator_lib.py b/tensorflow/contrib/kfac/python/ops/estimator_lib.py
new file mode 100644
index 0000000000..9c9fef471f
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/estimator_lib.py
@@ -0,0 +1,31 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines the high-level Fisher estimator class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.kfac.python.ops.estimator import *
+from tensorflow.python.util.all_util import remove_undocumented
+# pylint: enable=unused-import,line-too-long,wildcard-import
+
+_allowed_symbols = [
+ 'FisherEstimator',
+ 'make_fisher_estimator',
+]
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
new file mode 100644
index 0000000000..9fa6eb7dcd
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -0,0 +1,1752 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""FisherBlock definitions.
+
+This library contains classes for estimating blocks in a model's Fisher
+Information matrix. Suppose one has a model that parameterizes a posterior
+distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its
+Fisher Information matrix is given by,
+
+ $$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$
+
+where,
+
+ $$v(x, y, params) = (d / d params) log p(y | x, params)$$
+
+and the expectation is taken with respect to the data's distribution for 'x' and
+the model's posterior distribution for 'y',
+
+ x ~ p(x)
+ y ~ p(y | x, params)
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import enum # pylint: disable=g-bad-import-order
+
+import numpy as np
+import six
+
+from tensorflow.contrib.kfac.python.ops import fisher_factors
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util import nest
+
+# For blocks corresponding to convolutional layers, or any type of block where
+# the parameters can be thought of as being replicated in time or space,
+# we want to adjust the scale of the damping by
+# damping /= num_replications ** NORMALIZE_DAMPING_POWER
+NORMALIZE_DAMPING_POWER = 1.0
+
+# Methods for adjusting damping for FisherBlocks. See
+# compute_pi_adjusted_damping() for details.
+PI_OFF_NAME = "off"
+PI_TRACENORM_NAME = "tracenorm"
+PI_TYPE = PI_TRACENORM_NAME
+
+
+def set_global_constants(normalize_damping_power=None, pi_type=None):
+ """Sets various global constants used by the classes in this module."""
+ global NORMALIZE_DAMPING_POWER
+ global PI_TYPE
+
+ if normalize_damping_power is not None:
+ NORMALIZE_DAMPING_POWER = normalize_damping_power
+
+ if pi_type is not None:
+ PI_TYPE = pi_type
+
+
+def normalize_damping(damping, num_replications):
+ """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER."""
+ if NORMALIZE_DAMPING_POWER:
+ return damping / (num_replications ** NORMALIZE_DAMPING_POWER)
+ return damping
+
+
+def compute_pi_tracenorm(left_cov, right_cov):
+ r"""Computes the scalar constant pi for Tikhonov regularization/damping.
+
+ $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$
+ See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.
+
+ Args:
+ left_cov: A LinearOperator object. The left Kronecker factor "covariance".
+ right_cov: A LinearOperator object. The right Kronecker factor "covariance".
+
+ Returns:
+ The computed scalar constant pi for these Kronecker Factors (as a Tensor).
+ """
+ # Instead of dividing by the dim of the norm, we multiply by the dim of the
+ # other norm. This works out the same in the ratio.
+ left_norm = left_cov.trace() * int(right_cov.domain_dimension)
+ right_norm = right_cov.trace() * int(left_cov.domain_dimension)
+ return math_ops.sqrt(left_norm / right_norm)
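A small numerical illustration of the formula above (not part of the diff), using plain numpy arrays in place of the LinearOperator inputs:

    import numpy as np

    left_cov = np.diag([4., 4., 4.])   # trace 12, dimension 3
    right_cov = np.diag([1., 1.])      # trace  2, dimension 2
    pi = np.sqrt((np.trace(left_cov) / 3) / (np.trace(right_cov) / 2))  # 2.0
    # With PI_TYPE == "tracenorm", compute_pi_adjusted_damping() below would
    # then return (damping * 2.0, damping / 2.0).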
+
+
+def compute_pi_adjusted_damping(left_cov, right_cov, damping):
+
+ if PI_TYPE == PI_TRACENORM_NAME:
+ pi = compute_pi_tracenorm(left_cov, right_cov)
+ return (damping * pi, damping / pi)
+
+ elif PI_TYPE == PI_OFF_NAME:
+ return (damping, damping)
+
+
+class PackagedFunc(object):
+ """A Python thunk with a stable ID.
+
+ Enables stable names for lambdas.
+ """
+
+ def __init__(self, func, func_id):
+ """Initializes PackagedFunc.
+
+ Args:
+ func: a zero-arg Python function.
+ func_id: a hashable, a function that produces a hashable, or a list/tuple
+ thereof.
+ """
+ self._func = func
+ func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,)
+ self._func_id = func_id
+
+ def __call__(self):
+ return self._func()
+
+ @property
+ def func_id(self):
+ """A hashable identifier for this function."""
+ return tuple(elt() if callable(elt) else elt for elt in self._func_id)
+
+
+def _package_func(func, func_id):
+ return PackagedFunc(func, func_id)
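A quick illustration (not from the diff) of the "stable ID" property: two distinct lambdas that wrap the same damping value compare equal through `func_id`, which is presumably what downstream caching keys on:

    f1 = _package_func(lambda: 0.001, (0.001,))
    f2 = _package_func(lambda: 0.001, (0.001,))
    assert f1.func_id == f2.func_id == (0.001,)  # stable, hashable identity
    assert f1() == f2() == 0.001                 # calling still runs the lambda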
+
+
+@six.add_metaclass(abc.ABCMeta)
+class FisherBlock(object):
+ """Abstract base class for objects modeling approximate Fisher matrix blocks.
+
+ Subclasses must implement register_matpower, multiply_matpower,
+ instantiate_factors, tensors_to_compute_grads, and num_registered_towers
+ methods.
+ """
+
+ def __init__(self, layer_collection):
+ self._layer_collection = layer_collection
+
+ @abc.abstractmethod
+ def instantiate_factors(self, grads_list, damping):
+ """Creates and registers the component factors of this Fisher block.
+
+ Args:
+ grads_list: A list of gradients (each a Tensor or tuple of Tensors) with
+ respect to the tensors returned by tensors_to_compute_grads() that
+ are to be used to estimate the block.
+ damping: The damping factor (float or Tensor).
+ """
+ pass
+
+ @abc.abstractmethod
+ def register_matpower(self, exp):
+ """Registers a matrix power to be computed by the block.
+
+ Args:
+ exp: A float representing the power to raise the block by.
+ """
+ pass
+
+ @abc.abstractmethod
+ def register_cholesky(self):
+ """Registers a Cholesky factor to be computed by the block."""
+ pass
+
+ @abc.abstractmethod
+ def register_cholesky_inverse(self):
+ """Registers an inverse Cholesky factor to be computed by the block."""
+ pass
+
+ def register_inverse(self):
+ """Registers a matrix inverse to be computed by the block."""
+ self.register_matpower(-1)
+
+ @abc.abstractmethod
+ def multiply_matpower(self, vector, exp):
+ """Multiplies the vector by the (damped) matrix-power of the block.
+
+ Args:
+ vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
+ exp: A float representing the power to raise the block by before
+ multiplying it by the vector.
+
+ Returns:
+ The vector left-multiplied by the (damped) matrix-power of the block.
+ """
+ pass
+
+ def multiply_inverse(self, vector):
+ """Multiplies the vector by the (damped) inverse of the block.
+
+ Args:
+ vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
+
+ Returns:
+ The vector left-multiplied by the (damped) inverse of the block.
+ """
+ return self.multiply_matpower(vector, -1)
+
+ def multiply(self, vector):
+ """Multiplies the vector by the (damped) block.
+
+ Args:
+ vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
+
+ Returns:
+ The vector left-multiplied by the (damped) block.
+ """
+ return self.multiply_matpower(vector, 1)
+
+ @abc.abstractmethod
+ def multiply_cholesky(self, vector, transpose=False):
+ """Multiplies the vector by the (damped) Cholesky-factor of the block.
+
+ Args:
+ vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
+ transpose: Bool. If true the Cholesky factor is transposed before
+ multiplying the vector. (Default: False)
+
+ Returns:
+ The vector left-multiplied by the (damped) Cholesky-factor of the block.
+ """
+ pass
+
+ @abc.abstractmethod
+ def multiply_cholesky_inverse(self, vector, transpose=False):
+ """Multiplies vector by the (damped) inverse Cholesky-factor of the block.
+
+ Args:
+ vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
+ transpose: Bool. If true the Cholesky factor inverse is transposed
+ before multiplying the vector. (Default: False)
+ Returns:
+ Vector left-multiplied by (damped) inverse Cholesky-factor of the block.
+ """
+ pass
+
+ @abc.abstractmethod
+ def tensors_to_compute_grads(self):
+ """Returns the Tensor(s) with respect to which this FisherBlock needs grads.
+ """
+ pass
+
+ @abc.abstractproperty
+ def num_registered_towers(self):
+ """Number of towers registered for this FisherBlock.
+
+ Typically equal to the number of towers in a multi-tower setup.
+ """
+ pass
+
+
+class FullFB(FisherBlock):
+ """FisherBlock using a full matrix estimate (no approximations).
+
+ FullFB uses a full matrix estimate (no approximations), and should only ever
+ be used for very low dimensional parameters.
+
+ Note that this uses the naive "square the sum estimator", and so is applicable
+ to any type of parameter in principle, but has very high variance.
+ """
+
+ def __init__(self, layer_collection, params):
+ """Creates a FullFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ params: The parameters of this layer (Tensor or tuple of Tensors).
+ """
+ self._batch_sizes = []
+ self._params = params
+
+ super(FullFB, self).__init__(layer_collection)
+
+ def instantiate_factors(self, grads_list, damping):
+ self._damping_func = _package_func(lambda: damping, (damping,))
+
+ self._factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullFactor, (grads_list, self._batch_size))
+
+ def register_matpower(self, exp):
+ self._factor.register_matpower(exp, self._damping_func)
+
+ def register_cholesky(self):
+ self._factor.register_cholesky(self._damping_func)
+
+ def register_cholesky_inverse(self):
+ self._factor.register_cholesky_inverse(self._damping_func)
+
+ def _multiply_matrix(self, matrix, vector, transpose=False):
+ vector_flat = utils.tensors_to_column(vector)
+ out_flat = matrix.matmul(vector_flat, adjoint=transpose)
+ return utils.column_to_tensors(vector, out_flat)
+
+ def multiply_matpower(self, vector, exp):
+ matrix = self._factor.get_matpower(exp, self._damping_func)
+ return self._multiply_matrix(matrix, vector)
+
+ def multiply_cholesky(self, vector, transpose=False):
+ matrix = self._factor.get_cholesky(self._damping_func)
+ return self._multiply_matrix(matrix, vector, transpose=transpose)
+
+ def multiply_cholesky_inverse(self, vector, transpose=False):
+ matrix = self._factor.get_cholesky_inverse(self._damping_func)
+ return self._multiply_matrix(matrix, vector, transpose=transpose)
+
+ def full_fisher_block(self):
+ """Explicitly constructs the full Fisher block."""
+ return self._factor.get_cov_as_linear_operator().to_dense()
+
+ def tensors_to_compute_grads(self):
+ return self._params
+
+ def register_additional_tower(self, batch_size):
+ """Register an additional tower.
+
+ Args:
+ batch_size: The batch size, used in the covariance estimator.
+ """
+ self._batch_sizes.append(batch_size)
+
+ @property
+ def num_registered_towers(self):
+ return len(self._batch_sizes)
+
+ @property
+ def _batch_size(self):
+ return math_ops.reduce_sum(self._batch_sizes)
+
+
+@six.add_metaclass(abc.ABCMeta)
+class DiagonalFB(FisherBlock):
+ """A base class for FisherBlocks that use diagonal approximations."""
+
+ def register_matpower(self, exp):
+ # Not needed for this. Matrix powers are computed on demand in the
+ # diagonal case
+ pass
+
+ def register_cholesky(self):
+ # Not needed for this. Cholesky factors are computed on demand in the
+ # diagonal case
+ pass
+
+ def register_cholesky_inverse(self):
+ # Not needed for this. Cholesky inverses are computed on demand in the
+ # diagonal case
+ pass
+
+ def _multiply_matrix(self, matrix, vector):
+ vector_flat = utils.tensors_to_column(vector)
+ out_flat = matrix.matmul(vector_flat)
+ return utils.column_to_tensors(vector, out_flat)
+
+ def multiply_matpower(self, vector, exp):
+ matrix = self._factor.get_matpower(exp, self._damping_func)
+ return self._multiply_matrix(matrix, vector)
+
+ def multiply_cholesky(self, vector, transpose=False):
+ matrix = self._factor.get_cholesky(self._damping_func)
+ return self._multiply_matrix(matrix, vector)
+
+ def multiply_cholesky_inverse(self, vector, transpose=False):
+ matrix = self._factor.get_cholesky_inverse(self._damping_func)
+ return self._multiply_matrix(matrix, vector)
+
+ def full_fisher_block(self):
+ return self._factor.get_cov_as_linear_operator().to_dense()
+
+
+class NaiveDiagonalFB(DiagonalFB):
+ """FisherBlock using a diagonal matrix approximation.
+
+ This type of approximation is generically applicable but quite primitive.
+
+ Note that this uses the naive "square the sum estimator", and so is applicable
+ to any type of parameter in principle, but has very high variance.
+ """
+
+ def __init__(self, layer_collection, params):
+ """Creates a NaiveDiagonalFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ params: The parameters of this layer (Tensor or tuple of Tensors).
+ """
+ self._params = params
+ self._batch_sizes = []
+
+ super(NaiveDiagonalFB, self).__init__(layer_collection)
+
+ def instantiate_factors(self, grads_list, damping):
+ self._damping_func = _package_func(lambda: damping, (damping,))
+
+ self._factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size))
+
+ def tensors_to_compute_grads(self):
+ return self._params
+
+ def register_additional_tower(self, batch_size):
+ """Register an additional tower.
+
+ Args:
+ batch_size: The batch size, used in the covariance estimator.
+ """
+ self._batch_sizes.append(batch_size)
+
+ @property
+ def num_registered_towers(self):
+ return len(self._batch_sizes)
+
+ @property
+ def _batch_size(self):
+ return math_ops.reduce_sum(self._batch_sizes)
+
+
+class InputOutputMultiTower(object):
+ """Mix-in class for blocks with inputs & outputs and multiple mini-batches."""
+
+ def __init__(self, *args, **kwargs):
+ self.__inputs = []
+ self.__outputs = []
+ super(InputOutputMultiTower, self).__init__(*args, **kwargs)
+
+ def _process_data(self, grads_list):
+ """Process data into the format used by the factors.
+
+ This function takes inputs and grads_list data and processes it into
+ one of the formats expected by the FisherFactor classes (depending on
+ the value of the global configuration variable TOWER_STRATEGY).
+
+ The initial format of self._inputs is expected to be a list of Tensors
+ over towers. Similarly grads_list is expected to be a list over sources
+ of such lists.
+
+ If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single
+ tensor (represented as a PartitionedTensor object) equal to the
+ concatenation (across towers) of all of the elements of self._inputs. And
+ similarly grads_list is formatted into a tuple (over sources) of such
+ tensors (also represented as PartitionedTensors).
+
+ If TOWER_STRATEGY is "separate", formatting of inputs and grads_list
+ remains unchanged from the initial format (although possibly converting
+ from lists into tuples).
+
+ Args:
+ grads_list: grads_list in its initial format (see above).
+
+ Returns:
+ inputs: self._inputs transformed into the appropriate format (see
+ above).
+ grads_list: grads_list transformed into the appropriate format (see
+ above).
+
+ Raises:
+ ValueError: if TOWER_STRATEGY is not one of "separate" or "concat".
+ """
+ inputs = self._inputs
+ # inputs is a list over towers of Tensors
+ # grads_list is a list of list with the first index being sources and the
+ # second being towers.
+ if fisher_factors.TOWER_STRATEGY == "concat":
+ # Merge towers together into a PartitionedTensor. We package it in
+ # a singleton tuple since the factors will expect a list over towers
+ inputs = (utils.PartitionedTensor(inputs),)
+ # Do the same for grads_list but preserve leading sources dimension
+ grads_list = tuple((utils.PartitionedTensor(grads),)
+ for grads in grads_list)
+ elif fisher_factors.TOWER_STRATEGY == "separate":
+ inputs = tuple(inputs)
+ grads_list = tuple(grads_list)
+
+ else:
+ raise ValueError("Global config variable TOWER_STRATEGY must be one of "
+ "'concat' or 'separate'.")
+
+ return inputs, grads_list
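A schematic (editorial, not part of the commit) of the two output formats described in the docstring, for two towers and a single source; the tensor names are placeholders:

    # TOWER_STRATEGY == "concat":
    #   inputs     == (PartitionedTensor([a_tower0, a_tower1]),)
    #   grads_list == ((PartitionedTensor([g_tower0, g_tower1]),),)
    # TOWER_STRATEGY == "separate":
    #   inputs     == (a_tower0, a_tower1)
    #   grads_list == ((g_tower0, g_tower1),)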
+
+ def tensors_to_compute_grads(self):
+ """Tensors to compute derivative of loss with respect to."""
+ return tuple(self._outputs)
+
+ def register_additional_tower(self, inputs, outputs):
+ self._inputs.append(inputs)
+ self._outputs.append(outputs)
+
+ @property
+ def num_registered_towers(self):
+ result = len(self._inputs)
+ assert result == len(self._outputs)
+ return result
+
+ @property
+ def _inputs(self):
+ return self.__inputs
+
+ @property
+ def _outputs(self):
+ return self.__outputs
+
+
+class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB):
+ """FisherBlock for fully-connected (dense) layers using a diagonal approx.
+
+ Estimates the Fisher Information matrix's diagonal entries for a fully
+ connected layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of
+ squares" estimator.
+
+ Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
+ into it. We are interested in Fisher(params)[i, i]. This is,
+
+ $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
+ = E[ v(x, y, params)[i] ^ 2 ]$$
+
+ Consider a fully connected layer in this model with (unshared) weight matrix
+ 'w'. For an example 'x' that produces layer inputs 'a' and output
+ preactivations 's',
+
+ $$v(x, y, w) = vec( a (d loss / d s)^T )$$
+
+ This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
+ to the layer's parameters 'w'.
+ """
+
+ def __init__(self, layer_collection, has_bias=False):
+ """Creates a FullyConnectedDiagonalFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ has_bias: Whether the component Kronecker factors have an additive bias.
+ (Default: False)
+ """
+ self._has_bias = has_bias
+
+ super(FullyConnectedDiagonalFB, self).__init__(layer_collection)
+
+ def instantiate_factors(self, grads_list, damping):
+ inputs, grads_list = self._process_data(grads_list)
+
+ self._factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedDiagonalFactor,
+ (inputs, grads_list, self._has_bias))
+
+ self._damping_func = _package_func(lambda: damping, (damping,))
+
+
+class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB):
+ """FisherBlock for 2-D convolutional layers using a diagonal approx.
+
+ Estimates the Fisher Information matrix's diagonal entries for a convolutional
+ layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares"
+ estimator.
+
+ Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
+ into it. We are interested in Fisher(params)[i, i]. This is,
+
+ $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
+ = E[ v(x, y, params)[i] ^ 2 ]$$
+
+ Consider a convolutional layer in this model with (unshared) filter matrix
+ 'w'. For an example image 'x' that produces layer inputs 'a' and output
+ preactivations 's',
+
+ $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$
+
+ where 'loc' is a single (x, y) location in an image.
+
+ This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
+ to the layer's parameters 'w'.
+ """
+
+ def __init__(self,
+ layer_collection,
+ params,
+ strides,
+ padding,
+ data_format=None,
+ dilations=None):
+ """Creates a ConvDiagonalFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ params: The parameters (Tensor or tuple of Tensors) of this layer. If
+ kernel alone, a Tensor of shape [kernel_height, kernel_width,
+ in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
+ containing the previous and a Tensor of shape [out_channels].
+ strides: The stride size in this layer (1-D Tensor of length 4).
+ padding: The padding in this layer (e.g. "SAME").
+ data_format: str or None. Format of input data.
+ dilations: List of 4 ints or None. Rate for dilation along all dimensions.
+
+ Raises:
+ ValueError: if strides is not length-4.
+ ValueError: if dilations is not length-4.
+ ValueError: if channel is not last dimension.
+ """
+ if len(strides) != 4:
+ raise ValueError("strides must contain 4 numbers.")
+
+ if dilations is None:
+ dilations = [1, 1, 1, 1]
+
+ if len(dilations) != 4:
+ raise ValueError("dilations must contain 4 numbers.")
+
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("data_format must be channels-last.")
+
+ self._strides = maybe_tuple(strides)
+ self._padding = padding
+ self._data_format = data_format
+ self._dilations = maybe_tuple(dilations)
+ self._has_bias = isinstance(params, (tuple, list))
+
+ fltr = params[0] if self._has_bias else params
+ self._filter_shape = tuple(fltr.shape.as_list())
+
+ if len(self._filter_shape) != 4:
+ raise ValueError(
+ "Convolution filter must be of shape"
+ " [filter_height, filter_width, in_channels, out_channels].")
+
+ super(ConvDiagonalFB, self).__init__(layer_collection)
+
+ def instantiate_factors(self, grads_list, damping):
+ inputs, grads_list = self._process_data(grads_list)
+
+ # Infer number of locations upon which convolution is applied.
+ self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
+ self._strides)
+
+ self._factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.ConvDiagonalFactor,
+ (inputs, grads_list, self._filter_shape, self._strides, self._padding,
+ self._data_format, self._dilations, self._has_bias))
+
+ def damping_func():
+ return self._num_locations * normalize_damping(damping,
+ self._num_locations)
+
+ damping_id = (self._num_locations, "mult", "normalize_damping", damping,
+ self._num_locations)
+ self._damping_func = _package_func(damping_func, damping_id)
+
+
+class KroneckerProductFB(FisherBlock):
+ """A base class for blocks with separate input and output Kronecker factors.
+
+ The Fisher block is approximated as a Kronecker product of the input and
+ output factors.
+ """
+
+ def _setup_damping(self, damping, normalization=None):
+ """Makes functions that compute the damping values for both factors."""
+ def compute_damping():
+ if normalization is not None:
+ maybe_normalized_damping = normalize_damping(damping, normalization)
+ else:
+ maybe_normalized_damping = damping
+
+ return compute_pi_adjusted_damping(
+ self._input_factor.get_cov_as_linear_operator(),
+ self._output_factor.get_cov_as_linear_operator(),
+ maybe_normalized_damping**0.5)
+
+ if normalization is not None:
+ damping_id = ("compute_pi_adjusted_damping",
+ "cov", self._input_factor.name,
+ "cov", self._output_factor.name,
+ "normalize_damping", damping, normalization, "power", 0.5)
+ else:
+ damping_id = ("compute_pi_adjusted_damping",
+ "cov", self._input_factor.name,
+ "cov", self._output_factor.name,
+ damping, "power", 0.5)
+
+ self._input_damping_func = _package_func(lambda: compute_damping()[0],
+ damping_id + ("ref", 0))
+ self._output_damping_func = _package_func(lambda: compute_damping()[1],
+ damping_id + ("ref", 1))
+
+ def register_matpower(self, exp):
+ self._input_factor.register_matpower(exp, self._input_damping_func)
+ self._output_factor.register_matpower(exp, self._output_damping_func)
+
+ def register_cholesky(self):
+ self._input_factor.register_cholesky(self._input_damping_func)
+ self._output_factor.register_cholesky(self._output_damping_func)
+
+ def register_cholesky_inverse(self):
+ self._input_factor.register_cholesky_inverse(self._input_damping_func)
+ self._output_factor.register_cholesky_inverse(self._output_damping_func)
+
+ @property
+ def _renorm_coeff(self):
+ """Kronecker factor multiplier coefficient.
+
+ If this FisherBlock is represented as 'FB = c * kron(left, right)', then
+ this is 'c'.
+
+ Returns:
+ 0-D Tensor.
+ """
+ return 1.0
+
+ def _multiply_factored_matrix(self, left_factor, right_factor, vector,
+ extra_scale=1.0, transpose_left=False,
+ transpose_right=False):
+ reshaped_vector = utils.layer_params_to_mat2d(vector)
+ reshaped_out = right_factor.matmul_right(reshaped_vector,
+ adjoint=transpose_right)
+ reshaped_out = left_factor.matmul(reshaped_out,
+ adjoint=transpose_left)
+ if extra_scale != 1.0:
+ reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype)
+ return utils.mat2d_to_layer_params(vector, reshaped_out)
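The two small matmuls above avoid ever materializing the Kronecker product. Assuming row-major flattening of the [input_dim, output_dim] parameter matrix V and symmetric covariance factors A (input) and G (output), the identity presumably being exploited is

  $$(A \otimes G) vec(V) = vec(A V G^T)$$

so a block-vector product costs two dense matrix multiplies instead of one huge one.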
+
+ def multiply_matpower(self, vector, exp):
+ left_factor = self._input_factor.get_matpower(
+ exp, self._input_damping_func)
+ right_factor = self._output_factor.get_matpower(
+ exp, self._output_damping_func)
+ extra_scale = float(self._renorm_coeff)**exp
+ return self._multiply_factored_matrix(left_factor, right_factor, vector,
+ extra_scale=extra_scale)
+
+ def multiply_cholesky(self, vector, transpose=False):
+ left_factor = self._input_factor.get_cholesky(self._input_damping_func)
+ right_factor = self._output_factor.get_cholesky(self._output_damping_func)
+ extra_scale = float(self._renorm_coeff)**0.5
+ return self._multiply_factored_matrix(left_factor, right_factor, vector,
+ extra_scale=extra_scale,
+ transpose_left=transpose,
+ transpose_right=not transpose)
+
+ def multiply_cholesky_inverse(self, vector, transpose=False):
+ left_factor = self._input_factor.get_cholesky_inverse(
+ self._input_damping_func)
+ right_factor = self._output_factor.get_cholesky_inverse(
+ self._output_damping_func)
+ extra_scale = float(self._renorm_coeff)**-0.5
+ return self._multiply_factored_matrix(left_factor, right_factor, vector,
+ extra_scale=extra_scale,
+ transpose_left=transpose,
+ transpose_right=not transpose)
+
+ def full_fisher_block(self):
+ """Explicitly constructs the full Fisher block.
+
+ Used for testing purposes. (In general, the result may be very large.)
+
+ Returns:
+ The full Fisher block.
+ """
+ left_factor = self._input_factor.get_cov_as_linear_operator().to_dense()
+ right_factor = self._output_factor.get_cov_as_linear_operator().to_dense()
+ return self._renorm_coeff * utils.kronecker_product(left_factor,
+ right_factor)
+
+
+class EmbeddingKFACFB(InputOutputMultiTower, KroneckerProductFB):
+ """K-FAC FisherBlock for embedding layers.
+
+ This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its
+ input factor is approximated by a diagonal matrix. In the case that each
+ example references exactly one embedding, this approximation is exact.
+
+ Does not support bias parameters.
+ """
+
+ def __init__(self, layer_collection, vocab_size):
+ """Creates a EmbeddingKFACFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ vocab_size: int. Size of vocabulary for this embedding layer.
+ """
+ self._vocab_size = vocab_size
+
+ super(EmbeddingKFACFB, self).__init__(layer_collection)
+
+ def instantiate_factors(self, grads_list, damping):
+ """Instantiate Kronecker Factors for this FisherBlock.
+
+ Args:
+ grads_list: List of list of Tensors. grads_list[i][j] is the
+ gradient of the loss with respect to 'outputs' from source 'i' and
+ tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
+ damping: 0-D Tensor or float. 'damping' * identity is approximately added
+ to this FisherBlock's Fisher approximation.
+ """
+ inputs, grads_list = self._process_data(grads_list)
+
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.EmbeddingInputKroneckerFactor,
+ (inputs, self._vocab_size))
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedKroneckerFactor, (grads_list,))
+ self._setup_damping(damping)
+
+
+class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):
+ """K-FAC FisherBlock for fully-connected (dense) layers.
+
+ This uses the Kronecker-factorized approximation from the original
+ K-FAC paper (https://arxiv.org/abs/1503.05671)
+ """
+
+ def __init__(self, layer_collection, has_bias=False):
+ """Creates a FullyConnectedKFACBasicFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ has_bias: Whether the component Kronecker factors have an additive bias.
+ (Default: False)
+ """
+ self._has_bias = has_bias
+
+ super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)
+
+ def instantiate_factors(self, grads_list, damping):
+ """Instantiate Kronecker Factors for this FisherBlock.
+
+ Args:
+ grads_list: List of list of Tensors. grads_list[i][j] is the
+ gradient of the loss with respect to 'outputs' from source 'i' and
+ tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
+ damping: 0-D Tensor or float. 'damping' * identity is approximately added
+ to this FisherBlock's Fisher approximation.
+ """
+ inputs, grads_list = self._process_data(grads_list)
+
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedKroneckerFactor,
+ ((inputs,), self._has_bias))
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedKroneckerFactor,
+ (grads_list,))
+ self._setup_damping(damping)
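For reference, these two factors implement the standard approximation from the paper cited in the class docstring: with layer inputs 'a' and output preactivations 's', the block is taken (roughly) as

  $$F(W) \approx E[ a a^T ] \otimes E[ (d loss / d s)(d loss / d s)^T ]$$

mirroring the per-location version written out for ConvKFCBasicFB below.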
+
+
+class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
+ r"""FisherBlock for convolutional layers using the basic KFC approx.
+
+ Estimates the Fisher Information matrix's block for a convolutional
+ layer.
+
+ Consider a convolutional layer in this model with (unshared) filter matrix
+ 'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
+ this FisherBlock estimates,
+
+ $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T],
+ E[flat(ds) flat(ds)^T])$$
+
+ where
+
+ $$ds = (d / ds) log p(y | x, w)$$
+ #locations = number of (x, y) locations where 'w' is applied.
+
+ where the expectation is taken over all examples and locations and flat()
+ concatenates an array's leading dimensions.
+
+ See equation 23 in https://arxiv.org/abs/1602.01407 for details.
+ """
+
+ def __init__(self,
+ layer_collection,
+ params,
+ padding,
+ strides=None,
+ dilation_rate=None,
+ data_format=None,
+ extract_patches_fn=None):
+ """Creates a ConvKFCBasicFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ params: The parameters (Tensor or tuple of Tensors) of this layer. If
+ kernel alone, a Tensor of shape [..spatial_filter_shape..,
+ in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
+ containing the previous and a Tensor of shape [out_channels].
+ padding: str. Padding method.
+ strides: List of ints or None. Contains [..spatial_filter_strides..] if
+ 'extract_patches_fn' is compatible with tf.nn.convolution(), else
+ [1, ..spatial_filter_strides.., 1].
+ dilation_rate: List of ints or None. Rate for dilation along each spatial
+ dimension if 'extract_patches_fn' is compatible with
+ tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
+ data_format: str or None. Format of input data.
+ extract_patches_fn: str or None. Name of function that extracts image
+ patches. One of "extract_convolution_patches", "extract_image_patches",
+ "extract_pointwise_conv2d_patches".
+ """
+ self._padding = padding
+ self._strides = maybe_tuple(strides)
+ self._dilation_rate = maybe_tuple(dilation_rate)
+ self._data_format = data_format
+ self._extract_patches_fn = extract_patches_fn
+ self._has_bias = isinstance(params, (tuple, list))
+
+ fltr = params[0] if self._has_bias else params
+ self._filter_shape = tuple(fltr.shape.as_list())
+
+ super(ConvKFCBasicFB, self).__init__(layer_collection)
+
+ def instantiate_factors(self, grads_list, damping):
+ inputs, grads_list = self._process_data(grads_list)
+
+ # Infer number of locations upon which convolution is applied.
+ self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
+ self._strides)
+
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.ConvInputKroneckerFactor,
+ (inputs, self._filter_shape, self._padding, self._strides,
+ self._dilation_rate, self._data_format, self._extract_patches_fn,
+ self._has_bias))
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
+
+ self._setup_damping(damping, normalization=self._num_locations)
+
+ @property
+ def _renorm_coeff(self):
+ return self._num_locations
+
+
+class DepthwiseConvDiagonalFB(ConvDiagonalFB):
+ """FisherBlock for depthwise_conv2d().
+
+ Equivalent to ConvDiagonalFB applied to each input channel in isolation.
+ """
+
+ def __init__(self,
+ layer_collection,
+ params,
+ strides,
+ padding,
+ rate=None,
+ data_format=None):
+ """Creates a DepthwiseConvKFCBasicFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ params: Tensor of shape [filter_height, filter_width, in_channels,
+ channel_multiplier].
+ strides: List of 4 ints. Strides along all dimensions.
+ padding: str. Padding method.
+      rate: List of 2 ints or None. Rate for dilation along the spatial
+        dimensions.
+ data_format: str or None. Format of input data.
+
+ Raises:
+ NotImplementedError: If parameters contains bias.
+ ValueError: If filter is not 4-D.
+ ValueError: If strides is not length-4.
+      ValueError: If rate is not length-2.
+ ValueError: If channels are not last dimension.
+ """
+ if isinstance(params, (tuple, list)):
+ raise NotImplementedError("Bias not yet supported.")
+
+ if params.shape.ndims != 4:
+ raise ValueError("Filter must be 4-D.")
+
+ if len(strides) != 4:
+ raise ValueError("strides must account for 4 dimensions.")
+
+ if rate is not None:
+ if len(rate) != 2:
+ raise ValueError("rate must only account for spatial dimensions.")
+ rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate.
+
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("data_format must be channels-last.")
+
+ super(DepthwiseConvDiagonalFB, self).__init__(
+ layer_collection=layer_collection,
+ params=params,
+ strides=strides,
+ padding=padding,
+ dilations=rate,
+ data_format=data_format)
+
+    # This is a hack to overwrite the same setting in ConvDiagonalFB.__init__().
+ filter_height, filter_width, in_channels, channel_multiplier = (
+ params.shape.as_list())
+ self._filter_shape = (filter_height, filter_width, in_channels,
+ in_channels * channel_multiplier)
+
+ def _multiply_matrix(self, matrix, vector):
+ conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
+ conv2d_result = super(
+ DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector)
+ return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
+
+
+class DepthwiseConvKFCBasicFB(ConvKFCBasicFB):
+ """FisherBlock for depthwise_conv2d().
+
+ Equivalent to ConvKFCBasicFB applied to each input channel in isolation.
+ """
+
+ def __init__(self,
+ layer_collection,
+ params,
+ strides,
+ padding,
+ rate=None,
+ data_format=None):
+ """Creates a DepthwiseConvKFCBasicFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ params: Tensor of shape [filter_height, filter_width, in_channels,
+ channel_multiplier].
+ strides: List of 4 ints. Strides along all dimensions.
+ padding: str. Padding method.
+      rate: List of 2 ints or None. Rate for dilation along the spatial
+        dimensions.
+ data_format: str or None. Format of input data.
+
+ Raises:
+ NotImplementedError: If parameters contains bias.
+ ValueError: If filter is not 4-D.
+ ValueError: If strides is not length-4.
+      ValueError: If rate is not length-2.
+ ValueError: If channels are not last dimension.
+ """
+ if isinstance(params, (tuple, list)):
+ raise NotImplementedError("Bias not yet supported.")
+
+ if params.shape.ndims != 4:
+ raise ValueError("Filter must be 4-D.")
+
+ if len(strides) != 4:
+ raise ValueError("strides must account for 4 dimensions.")
+
+ if rate is not None:
+ if len(rate) != 2:
+ raise ValueError("rate must only account for spatial dimensions.")
+ rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate.
+
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("data_format must be channels-last.")
+
+ super(DepthwiseConvKFCBasicFB, self).__init__(
+ layer_collection=layer_collection,
+ params=params,
+ padding=padding,
+ strides=strides,
+ dilation_rate=rate,
+ data_format=data_format,
+ extract_patches_fn="extract_image_patches")
+
+ # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__().
+ filter_height, filter_width, in_channels, channel_multiplier = (
+ params.shape.as_list())
+ self._filter_shape = (filter_height, filter_width, in_channels,
+ in_channels * channel_multiplier)
+
+ def _multiply_factored_matrix(self, left_factor, right_factor, vector,
+ extra_scale=1.0, transpose_left=False,
+ transpose_right=False):
+ conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
+ conv2d_result = super(
+ DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix(
+ left_factor, right_factor, conv2d_vector, extra_scale=extra_scale,
+ transpose_left=transpose_left, transpose_right=transpose_right)
+ return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
+
+
+def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin
+ """Converts a convolution filter for use with conv2d.
+
+ Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's
+ compatible with tf.nn.conv2d().
+
+ Args:
+ filter: Tensor of shape [height, width, in_channels, channel_multiplier].
+ name: None or str. Name of Op.
+
+ Returns:
+ Tensor of shape [height, width, in_channels, out_channels].
+
+ """
+ with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter",
+ [filter]):
+ filter = ops.convert_to_tensor(filter)
+ filter_height, filter_width, in_channels, channel_multiplier = (
+ filter.shape.as_list())
+
+ results = []
+ for i in range(in_channels):
+ # Slice out one in_channel's filter. Insert zeros around it to force it
+ # to affect that channel and that channel alone.
+ elements = []
+ if i > 0:
+ elements.append(
+ array_ops.zeros(
+ [filter_height, filter_width, i, channel_multiplier]))
+ elements.append(filter[:, :, i:(i + 1), :])
+ if i + 1 < in_channels:
+ elements.append(
+ array_ops.zeros([
+ filter_height, filter_width, in_channels - (i + 1),
+ channel_multiplier
+ ]))
+
+ # Concat along in_channel.
+ results.append(
+ array_ops.concat(elements, axis=-2, name="in_channel_%d" % i))
+
+ # Concat along out_channel.
+ return array_ops.concat(results, axis=-1, name="out_channel")
+
+
+def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin
+ """Converts a convolution filter for use with depthwise_conv2d.
+
+ Transforms a filter for use with tf.nn.conv2d() to one that's
+ compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along
+ the diagonal.
+
+ Args:
+ filter: Tensor of shape [height, width, in_channels, out_channels].
+ name: None or str. Name of Op.
+
+ Returns:
+    Tensor of shape [height, width, in_channels, channel_multiplier].
+
+ Raises:
+ ValueError: if out_channels is not evenly divisible by in_channels.
+ """
+ with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter",
+ [filter]):
+ filter = ops.convert_to_tensor(filter)
+ filter_height, filter_width, in_channels, out_channels = (
+ filter.shape.as_list())
+
+ if out_channels % in_channels != 0:
+ raise ValueError("out_channels must be evenly divisible by in_channels.")
+ channel_multiplier = out_channels // in_channels
+
+ results = []
+ filter = array_ops.reshape(filter, [
+ filter_height, filter_width, in_channels, in_channels,
+ channel_multiplier
+ ])
+ for i in range(in_channels):
+ # Slice out output corresponding to the correct filter.
+ filter_slice = array_ops.reshape(
+ filter[:, :, i, i, :],
+ [filter_height, filter_width, 1, channel_multiplier])
+ results.append(filter_slice)
+
+ # Concat along out_channel.
+ return array_ops.concat(results, axis=-2, name="in_channels")
+
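+# A small sketch of how the two conversion helpers above compose (shapes are
+# illustrative, assuming a 3x3 depthwise filter with 2 input channels and
+# channel multiplier 1):
+#   dw = tf.random_normal([3, 3, 2, 1])
+#   cv = depthwise_conv2d_filter_to_conv2d_filter(dw)        # [3, 3, 2, 2]
+#   dw_again = conv2d_filter_to_depthwise_conv2d_filter(cv)  # [3, 3, 2, 1]
+# Converting to a conv2d filter and back recovers the original, since only the
+# "diagonal" filters produced by the first conversion are nonzero.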
+
+def maybe_tuple(obj):
+ if not isinstance(obj, list):
+ return obj
+ return tuple(obj)
+
+
+def num_conv_locations(input_shape, strides):
+ """Returns the number of spatial locations a 2D Conv kernel is applied to.
+
+ Args:
+ input_shape: List of ints representing shape of inputs to
+ tf.nn.convolution().
+ strides: List of ints representing strides along spatial dimensions as
+ passed in to tf.nn.convolution().
+
+ Returns:
+ A scalar |T| denoting the number of spatial locations for the Conv layer.
+ """
+ spatial_input_locations = np.prod(input_shape[1:-1])
+
+ if strides is None:
+ spatial_strides_divisor = 1
+ else:
+ spatial_strides_divisor = np.prod(strides)
+
+ return spatial_input_locations // spatial_strides_divisor
+
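+# For example (with illustrative numbers): for inputs of shape
+# [batch, 32, 32, 3] and spatial strides [2, 2], num_conv_locations returns
+# (32 * 32) // (2 * 2) = 256.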
+
+class InputOutputMultiTowerMultiUse(InputOutputMultiTower):
+ """Adds methods for multi-use/time-step case to InputOutputMultiTower."""
+
+ def __init__(self, num_uses=None, *args, **kwargs):
+ self._num_uses = num_uses
+ super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs)
+
+ def _process_data(self, grads_list):
+ """Process temporal/multi-use data into the format used by the factors.
+
+    This function takes the inputs and grads_list data and processes it into
+ one of the formats expected by the FisherFactor classes (depending on
+ the value of the global configuration variable TOWER_STRATEGY).
+
+ It accepts the data in one of two initial formats. The first possible
+ format is where self._inputs is a list of list of Tensors. The first index
+ is tower, the second is use/time-step. grads_list, meanwhile, is a list
+ over sources of such lists of lists.
+
+ The second possible data format is where self._inputs is a Tensor with
+    uses/time-steps folded into the batch dimension, i.e. it is a Tensor
+ of shape [num_uses * size_batch, ...] which represents a reshape of a
+ Tensor of shape [num_uses, size_batch, ...]. And similarly grads_list is
+ a list over sources of such Tensors.
+
+ There are two possible formats which inputs and grads_list are transformed
+ into.
+
+ If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing
+ a single tensor (represented as a PartitionedTensor object) with all of
+ the data from the towers, as well as the uses/time-steps, concatenated
+ together. In this tensor the leading dimension is the batch and
+ use/time-step dimensions folded together (with 'use' being the major of
+ these two, so that the tensors can be thought of as reshapes of ones of
+ shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a
+ tuple over sources of such tensors.
+
+ If TOWER_STRATEGY is "separate" the inputs are formatted into lists of
+ tensors over towers. Each of these tensors has a similar format to
+ the tensor produced by the "concat" option, except that each contains
+ only the data from a single tower. grads_list is similarly formatted
+ into a tuple over sources of such tuples.
+
+ Args:
+ grads_list: grads_list in its initial format (see above).
+
+ Returns:
+ inputs: self._inputs transformed into the appropriate format (see
+ above).
+ grads_list: grads_list transformed into the appropriate format (see
+ above).
+
+ Raises:
+ ValueError: If TOWER_STRATEGY is not one of "separate" or "concat".
+ ValueError: If the given/initial format of self._inputs and grads_list
+ isn't recognized, or doesn't agree with self._num_uses.
+ """
+
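+    # Illustrative shapes (assuming 2 towers, 3 uses, per-tower batch size b):
+    # in the list-of-lists format, self._inputs looks like
+    #   [[x_00, x_01, x_02], [x_10, x_11, x_12]]
+    # with each x_ij of shape [b, ...]. In the folded format the same data is
+    # instead given as Tensors of shape [3 * b, ...], with the use/time-step
+    # dimension folded into (and major to) the batch dimension.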
+ inputs = self._inputs
+
+ if isinstance(inputs[0], (list, tuple)):
+ num_uses = len(inputs[0])
+ if self._num_uses is not None and self._num_uses != num_uses:
+ raise ValueError("num_uses argument doesn't match length of inputs.")
+ else:
+ self._num_uses = num_uses
+
+ # Check that all mini-batches/towers have the same number of uses
+ if not all(len(input_) == num_uses for input_ in inputs):
+ raise ValueError("Length of inputs argument is inconsistent across "
+ "towers.")
+
+ if fisher_factors.TOWER_STRATEGY == "concat":
+ # Reverse the tower and use/time-step indices, so that use is now first,
+ # and towers is second
+ inputs = tuple(zip(*inputs))
+
+ # Flatten the two dimensions
+ inputs = nest.flatten(inputs)
+
+ # Merge everything together into a PartitionedTensor. We package it in
+ # a singleton tuple since the factors will expect a list over towers
+ inputs = (utils.PartitionedTensor(inputs),)
+
+ elif fisher_factors.TOWER_STRATEGY == "separate":
+ # Merge together the uses/time-step dimension into PartitionedTensors,
+ # but keep the leading dimension (towers) intact for the factors to
+ # process individually.
+ inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs)
+
+ else:
+ raise ValueError("Global config variable TOWER_STRATEGY must be one of "
+ "'concat' or 'separate'.")
+ else:
+ inputs = tuple(inputs)
+
+ # Now we perform the analogous processing for grads_list
+ if isinstance(grads_list[0][0], (list, tuple)):
+ num_uses = len(grads_list[0][0])
+ if self._num_uses is not None and self._num_uses != num_uses:
+ raise ValueError("num_uses argument doesn't match length of outputs, "
+ "or length of outputs is inconsistent with length of "
+ "inputs.")
+ else:
+ self._num_uses = num_uses
+
+ if not all(len(grad) == num_uses for grads in grads_list
+ for grad in grads):
+ raise ValueError("Length of outputs argument is inconsistent across "
+ "towers.")
+
+ if fisher_factors.TOWER_STRATEGY == "concat":
+ # Reverse the tower and use/time-step indices, so that use is now first,
+ # and towers is second
+ grads_list = tuple(tuple(zip(*grads)) for grads in grads_list)
+
+ # Flatten the two dimensions, leaving the leading dimension (source)
+ # intact
+ grads_list = tuple(nest.flatten(grads) for grads in grads_list)
+
+ # Merge inner dimensions together into PartitionedTensors. We package
+ # them in a singleton tuple since the factors will expect a list over
+ # towers
+ grads_list = tuple((utils.PartitionedTensor(grads),)
+ for grads in grads_list)
+
+ elif fisher_factors.TOWER_STRATEGY == "separate":
+ # Merge together the uses/time-step dimension into PartitionedTensors,
+ # but keep the leading dimension (towers) intact for the factors to
+ # process individually.
+ grads_list = tuple(tuple(utils.PartitionedTensor(grad)
+ for grad in grads)
+ for grads in grads_list)
+
+ else:
+ raise ValueError("Global config variable TOWER_STRATEGY must be one of "
+ "'concat' or 'separate'.")
+ else:
+ grads_list = tuple(tuple(grads) for grads in grads_list)
+
+ if self._num_uses is None:
+ raise ValueError("You must supply a value for the num_uses argument if "
+ "the number of uses cannot be inferred from inputs or "
+ "outputs arguments (e.g. if they are both given in the "
+ "single Tensor format, instead of as lists of Tensors.")
+
+ return inputs, grads_list
+
+
+class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse,
+ KroneckerProductFB):
+ """FisherBlock for fully-connected layers that share parameters.
+
+ This class implements the "independence across time" approximation from the
+ following paper:
+ https://openreview.net/pdf?id=HyMTkQZAb
+ """
+
+ def __init__(self, layer_collection, has_bias=False, num_uses=None):
+ """Creates a FullyConnectedMultiIndepFB block.
+
+ Args:
+ layer_collection: LayerCollection instance.
+ has_bias: bool. If True, estimates Fisher with respect to a bias
+ parameter as well as the layer's parameters.
+ num_uses: int or None. Number of uses of the layer in the model's graph.
+ Only required if the data is formatted with uses/time folded into the
+ batch dimension (instead of uses/time being a list dimension).
+ (Default: None)
+ """
+ self._has_bias = has_bias
+
+ super(FullyConnectedMultiIndepFB, self).__init__(
+ layer_collection=layer_collection,
+ num_uses=num_uses)
+
+ def instantiate_factors(self, grads_list, damping):
+ inputs, grads_list = self._process_data(grads_list)
+
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedMultiKF,
+ ((inputs,), self._num_uses, self._has_bias))
+
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
+
+ self._setup_damping(damping, normalization=self._num_uses)
+
+ @property
+ def _renorm_coeff(self):
+ return float(self._num_uses)
+
+
+class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse,
+ KroneckerProductFB):
+ """FisherBlock for 2D convolutional layers using the basic KFC approx.
+
+ Similar to ConvKFCBasicFB except that this version supports multiple
+ uses/time-steps via a standard independence approximation. Similar to the
+ "independence across time" used in FullyConnectedMultiIndepFB but generalized
+ in the obvious way to conv layers.
+ """
+
+ def __init__(self,
+ layer_collection,
+ params,
+ padding,
+ strides=None,
+ dilation_rate=None,
+ data_format=None,
+ extract_patches_fn=None,
+ num_uses=None):
+ """Creates a ConvKFCBasicMultiIndepFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ params: The parameters (Tensor or tuple of Tensors) of this layer. If
+ kernel alone, a Tensor of shape [..spatial_filter_shape..,
+        in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
+        containing the previous kernel Tensor and a bias Tensor of shape
+        [out_channels].
+ padding: str. Padding method.
+ strides: List of ints or None. Contains [..spatial_filter_strides..] if
+ 'extract_patches_fn' is compatible with tf.nn.convolution(), else
+        [1, ..spatial_filter_strides.., 1].
+ dilation_rate: List of ints or None. Rate for dilation along each spatial
+ dimension if 'extract_patches_fn' is compatible with
+ tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
+ data_format: str or None. Format of input data.
+ extract_patches_fn: str or None. Name of function that extracts image
+ patches. One of "extract_convolution_patches", "extract_image_patches",
+ "extract_pointwise_conv2d_patches".
+ num_uses: int or None. Number of uses of the layer in the model's graph.
+ Only required if the data is formatted with uses/time folded into the
+ batch dimension (instead of uses/time being a list dimension).
+ (Default: None)
+ """
+ self._padding = padding
+ self._strides = maybe_tuple(strides)
+ self._dilation_rate = maybe_tuple(dilation_rate)
+ self._data_format = data_format
+ self._extract_patches_fn = extract_patches_fn
+ self._has_bias = isinstance(params, (tuple, list))
+
+ fltr = params[0] if self._has_bias else params
+ self._filter_shape = tuple(fltr.shape.as_list())
+
+ super(ConvKFCBasicMultiIndepFB, self).__init__(
+ layer_collection=layer_collection,
+ num_uses=num_uses)
+
+ def instantiate_factors(self, grads_list, damping):
+ inputs, grads_list = self._process_data(grads_list)
+
+ # Infer number of locations upon which convolution is applied.
+ self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
+ self._strides)
+
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.ConvInputKroneckerFactor,
+ (inputs, self._filter_shape, self._padding, self._strides,
+ self._dilation_rate, self._data_format, self._extract_patches_fn,
+ self._has_bias))
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
+
+ self._setup_damping(damping, normalization=
+ (self._num_locations * self._num_uses))
+
+ @property
+ def _renorm_coeff(self):
+ return self._num_locations * self._num_uses
+
+
+class EmbeddingKFACMultiIndepFB(InputOutputMultiTowerMultiUse,
+ KroneckerProductFB):
+ """K-FAC FisherBlock for embedding layers used multiple times in the graph.
+
+ Similar to EmbeddingKFACFB except that this version supports multiple uses
+ of the parameter within a single model. These uses could correspond to time
+ steps in an RNN architecture, but they don't have to.
+
+ Does not support bias parameters.
+ """
+
+ def __init__(self, layer_collection, vocab_size, num_uses=None):
+ """Creates a EmbeddingKFACMultiIndepFB block.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ vocab_size: int. Size of vocabulary for this embedding layer.
+ num_uses: int or None. Number of uses of the layer in the model's graph.
+ Only required if the data is formatted with time folded into the batch
+ dimension (instead of time being a list dimension). (Default: None)
+ """
+ self._vocab_size = vocab_size
+
+ super(EmbeddingKFACMultiIndepFB, self).__init__(
+ layer_collection=layer_collection,
+ num_uses=num_uses)
+
+ def instantiate_factors(self, grads_list, damping):
+ """Instantiate Kronecker Factors for this FisherBlock.
+
+ Args:
+ grads_list: List of list of list of Tensors. grads_list[i][j][k] is the
+ gradient of the loss with respect to 'outputs' from source 'i',
+ tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape
+ [tower_minibatch_size, output_size].
+ damping: 0-D Tensor or float. 'damping' * identity is approximately added
+ to this FisherBlock's Fisher approximation.
+ """
+ inputs, grads_list = self._process_data(grads_list)
+
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.EmbeddingInputKroneckerFactor,
+ (inputs, self._vocab_size))
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
+ self._setup_damping(damping, normalization=self._num_uses)
+
+ @property
+ def _renorm_coeff(self):
+ return float(self._num_uses)
+
+
+class SeriesFBApproximation(enum.IntEnum):
+ """See FullyConnectedSeriesFB.__init__ for description and usage."""
+ option1 = 1
+ option2 = 2
+
+
+class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
+ KroneckerProductFB):
+ """FisherBlock for fully-connected layers that share parameters across time.
+
+ This class implements the "Option 1" and "Option 2" approximation from the
+ following paper:
+ https://openreview.net/pdf?id=HyMTkQZAb
+
+  See the end of the appendix of the paper for pseudo-code of the
+  algorithm implemented by multiply_matpower here. Note that we are
+ using pre-computed versions of certain matrix-matrix products to speed
+ things up. This is explicitly explained wherever it is done.
+ """
+
+ def __init__(self,
+ layer_collection,
+ has_bias=False,
+ num_uses=None,
+ option=SeriesFBApproximation.option2):
+ """Constructs a new `FullyConnectedSeriesFB`.
+
+ Args:
+ layer_collection: The collection of all layers in the K-FAC approximate
+ Fisher information matrix to which this FisherBlock belongs.
+ has_bias: Whether the layer includes a bias parameter.
+ num_uses: int or None. Number of time-steps over which the layer
+ is used. Only required if the data is formatted with time folded into
+ the batch dimension (instead of time being a list dimension).
+ (Default: None)
+ option: A `SeriesFBApproximation` specifying the simplifying assumption
+ to be used in this block. `option1` approximates the cross-covariance
+ over time as a symmetric matrix, while `option2` makes
+ the assumption that training sequences are infinitely long. See section
+ 3.5 of the paper for more details.
+ """
+
+ self._has_bias = has_bias
+ self._option = option
+
+ super(FullyConnectedSeriesFB, self).__init__(
+ layer_collection=layer_collection,
+ num_uses=num_uses)
+
+ @property
+ def _num_timesteps(self):
+ return self._num_uses
+
+ @property
+ def _renorm_coeff(self):
+ # This should no longer be used since the multiply_X functions from the base
+ # class have been overridden
+ assert False
+
+ def instantiate_factors(self, grads_list, damping):
+ inputs, grads_list = self._process_data(grads_list)
+
+ self._input_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedMultiKF,
+ ((inputs,), self._num_uses, self._has_bias))
+ self._input_factor.register_cov_dt1()
+
+ self._output_factor = self._layer_collection.make_or_get_factor(
+ fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
+ self._output_factor.register_cov_dt1()
+
+ self._setup_damping(damping, normalization=self._num_uses)
+
+ def register_matpower(self, exp):
+ if exp != -1:
+ raise NotImplementedError("FullyConnectedSeriesFB only supports inverse"
+ "multiplications.")
+
+ if self._option == SeriesFBApproximation.option1:
+ self._input_factor.register_option1quants(self._input_damping_func)
+ self._output_factor.register_option1quants(self._output_damping_func)
+ elif self._option == SeriesFBApproximation.option2:
+ self._input_factor.register_option2quants(self._input_damping_func)
+ self._output_factor.register_option2quants(self._output_damping_func)
+ else:
+ raise ValueError(
+ "Unrecognized FullyConnectedSeriesFB approximation: {}".format(
+ self._option))
+
+ def multiply_matpower(self, vector, exp):
+ if exp != -1:
+ raise NotImplementedError("FullyConnectedSeriesFB only supports inverse"
+ "multiplications.")
+
+ # pylint: disable=invalid-name
+
+ Z = utils.layer_params_to_mat2d(vector)
+
+ # Derivations were done for "batch_dim==1" case so we need to convert to
+ # that orientation:
+ Z = array_ops.transpose(Z)
+
+ if self._option == SeriesFBApproximation.option1:
+
+ # Note that \\(L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.\\)
+ L_A, psi_A = self._input_factor.get_option1quants(
+ self._input_damping_func)
+ L_G, psi_G = self._output_factor.get_option1quants(
+ self._output_damping_func)
+
+ def gamma(x):
+ # We are assuming that each case has the same number of time-steps.
+ # If this stops being the case one shouldn't simply replace this T
+ # with its average value. Instead, one needs to go back to the
+ # definition of the gamma function from the paper.
+ T = self._num_timesteps
+ return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T))
+
+ # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise)
+ # Even though Y is Z-independent we are recomputing it from the psi's
+      # each time, since Y depends on both A and G quantities and it is
+      # relatively cheap to compute.
+ Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A)
+
+ # \\(Z = L_G^T * Z * L_A\\)
+ # This is equivalent to the following computation from the original
+ # pseudo-code:
+ # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
+ # \\(Z = U_G^T * Z * U_A\\)
+ Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True)
+
+ # \\(Z = Z .* Y\\)
+ Z *= Y
+
+ # \\(Z = L_G * Z * L_A^T\\)
+ # This is equivalent to the following computation from the original
+ # pseudo-code:
+ # \\(Z = U_G * Z * U_A^T\\)
+ # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
+ Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True))
+
+ elif self._option == SeriesFBApproximation.option2:
+
+ # Note that \\(P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1}\\),
+ # and \\(K_A = A_0^{-1/2} * E_A\ and\ K_G = G_0^{-1/2} * E_G.\\)
+ P_A, K_A, mu_A = self._input_factor.get_option2quants(
+ self._input_damping_func)
+ P_G, K_G, mu_G = self._output_factor.get_option2quants(
+ self._output_damping_func)
+
+ # Our approach differs superficially from the pseudo-code in the paper
+ # in order to reduce the total number of matrix-matrix multiplies.
+ # In particular, the first three computations in the pseudo code are
+ # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
+ # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\)
+ # \\(Z = E_G^T * Z * E_A\\)
+      # Noting that \\(hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that
+ # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\)
+ # the entire computation can be written as
+ # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
+ # \\( - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\)
+ # \\( = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
+ # \\( - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\)
+ # \\( = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\)
+ # \\( - E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\)
+ # \\( = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A\\)
+ # This final expression is computed by the following two lines:
+ # \\(Z = Z - P_G * Z * P_A^T\\)
+ Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True))
+ # \\(Z = K_G^T * Z * K_A\\)
+ Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True)
+
+ # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\)
+ # Be careful with the outer product. We don't want to accidentally
+ # make it an inner-product instead.
+ tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A
+ # Prevent some numerical issues by setting any 0.0 eigs to 1.0
+ tmp += 1.0 * math_ops.cast(math_ops.equal(tmp, 0.0), dtype=tmp.dtype)
+ Z /= tmp
+
+ # We now perform the transpose/reverse version of the operations
+ # derived above, whose derivation from the original pseudo-code is
+      # analogous.
+ # \\(Z = K_G * Z * K_A^T\\)
+ Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True))
+
+ # \\(Z = Z - P_G^T * Z * P_A\\)
+ Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True)
+
+      # Normalize: \\(Z = (1/E[T]) * Z\\)
+ # Note that this normalization is done because we compute the statistics
+ # by averaging, not summing, over time. (And the gradient is presumably
+ # summed over time, not averaged, and thus their scales are different.)
+ Z /= math_ops.cast(self._num_timesteps, Z.dtype)
+
+ # Convert back to the "batch_dim==0" orientation.
+ Z = array_ops.transpose(Z)
+
+ return utils.mat2d_to_layer_params(vector, Z)
+
+ # pylint: enable=invalid-name
+
+ def multiply_cholesky(self, vector):
+ raise NotImplementedError("FullyConnectedSeriesFB does not support "
+ "Cholesky computations.")
+
+ def multiply_cholesky_inverse(self, vector):
+ raise NotImplementedError("FullyConnectedSeriesFB does not support "
+ "Cholesky computations.")
+
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
new file mode 100644
index 0000000000..c04cf727fa
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
@@ -0,0 +1,45 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""FisherBlock definitions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.kfac.python.ops.fisher_blocks import *
+from tensorflow.python.util.all_util import remove_undocumented
+# pylint: enable=unused-import,line-too-long,wildcard-import
+
+_allowed_symbols = [
+ 'FisherBlock',
+ 'FullFB',
+ 'NaiveDiagonalFB',
+ 'FullyConnectedDiagonalFB',
+ 'KroneckerProductFB',
+ 'EmbeddingKFACFB',
+ 'FullyConnectedKFACBasicFB',
+ 'ConvKFCBasicFB',
+ 'ConvDiagonalFB',
+ 'set_global_constants',
+ 'compute_pi_tracenorm',
+ 'compute_pi_adjusted_damping',
+ 'num_conv_locations',
+ 'normalize_damping',
+ 'LEFT_MULTIPLY',
+ 'RIGHT_MULTIPLY',
+]
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
new file mode 100644
index 0000000000..afa2fd1ca7
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -0,0 +1,1830 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""FisherFactor definitions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import contextlib
+
+import numpy as np
+import six
+
+from tensorflow.contrib.kfac.python.ops import linear_operator as lo
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import special_math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.training import moving_averages
+from tensorflow.python.util import nest
+
+
+# Whether to initialize covariance estimators at a zero matrix (or the identity
+# matrix).
+INIT_COVARIANCES_AT_ZERO = True
+
+# Whether to zero-debias the moving averages.
+ZERO_DEBIAS = True
+
+# Whether to initialize inverse (and other such matrices computed from the cov
+# matrices) to the zero matrix (or the identity matrix).
+INIT_INVERSES_AT_ZERO = True
+
+# When the number of inverses requested from a FisherFactor exceeds this value,
+# the inverses are computed using an eigenvalue decomposition.
+EIGENVALUE_DECOMPOSITION_THRESHOLD = 2
+
+# Numerical eigenvalues computed from covariance matrix estimates are clipped to
+# be at least as large as this value before they are used to compute inverses or
+# matrix powers. Must be nonnegative.
+EIGENVALUE_CLIPPING_THRESHOLD = 0.0
+
+# Used to subsample the flattened extracted image patches. The number of
+# outer products per row of the covariance matrix should not exceed this
+# value. This parameter is used only if `_SUB_SAMPLE_OUTER_PRODUCTS` is True.
+_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1
+
+# Used to subsample the inputs passed to extract image patches. The batch size
+# (i.e. the number of inputs) passed to extract image patches is multiplied by
+# this factor. This parameter is used only if `_SUB_SAMPLE_INPUTS` is True.
+_INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.5
+
+# If True, then subsamples the tensor passed to compute the covariance matrix.
+_SUB_SAMPLE_OUTER_PRODUCTS = False
+
+# If True, then subsamples the tensor passed to compute the covariance matrix.
+_SUB_SAMPLE_INPUTS = False
+
+# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data
+# passed to the factors from the blocks will be concatenated across towers
+# (lazily via PartitionedTensor objects). Otherwise a tuple of tensors over
+# towers will be passed in, and the factors will iterate over this and do the
+# cov computations separately for each one, averaging the results together.
+TOWER_STRATEGY = "concat"
+
+
+def set_global_constants(init_covariances_at_zero=None,
+ zero_debias=None,
+ init_inverses_at_zero=None,
+ eigenvalue_decomposition_threshold=None,
+ eigenvalue_clipping_threshold=None,
+ max_num_outer_products_per_cov_row=None,
+ sub_sample_outer_products=None,
+ inputs_to_extract_patches_factor=None,
+ sub_sample_inputs=None,
+ tower_strategy=None):
+ """Sets various global constants used by the classes in this module."""
+ global INIT_COVARIANCES_AT_ZERO
+ global ZERO_DEBIAS
+ global INIT_INVERSES_AT_ZERO
+ global EIGENVALUE_DECOMPOSITION_THRESHOLD
+ global EIGENVALUE_CLIPPING_THRESHOLD
+ global _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW
+ global _SUB_SAMPLE_OUTER_PRODUCTS
+ global _INPUTS_TO_EXTRACT_PATCHES_FACTOR
+ global _SUB_SAMPLE_INPUTS
+ global TOWER_STRATEGY
+
+ if init_covariances_at_zero is not None:
+ INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero
+ if zero_debias is not None:
+ ZERO_DEBIAS = zero_debias
+ if init_inverses_at_zero is not None:
+ INIT_INVERSES_AT_ZERO = init_inverses_at_zero
+ if eigenvalue_decomposition_threshold is not None:
+ EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold
+ if eigenvalue_clipping_threshold is not None:
+ EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold
+ if max_num_outer_products_per_cov_row is not None:
+ _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = max_num_outer_products_per_cov_row
+ if sub_sample_outer_products is not None:
+ _SUB_SAMPLE_OUTER_PRODUCTS = sub_sample_outer_products
+ if inputs_to_extract_patches_factor is not None:
+ _INPUTS_TO_EXTRACT_PATCHES_FACTOR = inputs_to_extract_patches_factor
+ if sub_sample_inputs is not None:
+ _SUB_SAMPLE_INPUTS = sub_sample_inputs
+ if tower_strategy is not None:
+ TOWER_STRATEGY = tower_strategy
+
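+# Example usage (illustrative): the defaults above can be overridden before any
+# factors are constructed, e.g.
+#   fisher_factors.set_global_constants(tower_strategy="separate",
+#                                       zero_debias=False)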
+
+def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
+ if INIT_INVERSES_AT_ZERO:
+ return array_ops.zeros(shape, dtype=dtype)
+ return linalg_ops.eye(num_rows=shape[0], dtype=dtype)
+
+
+def covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
+ if INIT_COVARIANCES_AT_ZERO:
+ return array_ops.zeros(shape, dtype=dtype)
+ return linalg_ops.eye(num_rows=shape[0], dtype=dtype)
+
+
+def diagonal_covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
+ if INIT_COVARIANCES_AT_ZERO:
+ return array_ops.zeros(shape, dtype=dtype)
+ return array_ops.ones(shape, dtype=dtype)
+
+
+@contextlib.contextmanager
+def place_on_device(device):
+ if device is not None and len(device):
+ with tf_ops.device(device):
+ yield
+ else:
+ yield
+
+
+def compute_cov(tensor, tensor_right=None, normalizer=None):
+ """Compute the empirical second moment of the rows of a 2D Tensor.
+
+ This function is meant to be applied to random matrices for which the true row
+ mean is zero, so that the true second moment equals the true covariance.
+
+ Args:
+ tensor: A 2D Tensor.
+ tensor_right: An optional 2D Tensor. If provided, this function computes
+ the matrix product tensor^T * tensor_right instead of tensor^T * tensor.
+ normalizer: optional scalar for the estimator (by default, the normalizer is
+ the number of rows of tensor).
+
+ Returns:
+ A square 2D Tensor with as many rows/cols as the number of input columns.
+ """
+ if normalizer is None:
+ normalizer = array_ops.shape(tensor)[0]
+ if tensor_right is None:
+ cov = (
+ math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast(
+ normalizer, tensor.dtype))
+ return (cov + array_ops.transpose(cov)) / math_ops.cast(2.0, cov.dtype)
+ else:
+ return (math_ops.matmul(tensor, tensor_right, transpose_a=True) /
+ math_ops.cast(normalizer, tensor.dtype))
+
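+# For instance (illustrative): for `tensor` of shape [batch_size, n],
+# compute_cov(tensor) returns the n x n matrix (tensor^T tensor) / batch_size,
+# symmetrized as (C + C^T) / 2.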
+
+def append_homog(tensor):
+ """Appends a homogeneous coordinate to the last dimension of a Tensor.
+
+ Args:
+ tensor: A Tensor.
+
+ Returns:
+ A Tensor identical to the input but one larger in the last dimension. The
+ new entries are filled with ones.
+ """
+ rank = len(tensor.shape.as_list())
+ shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0)
+ ones = array_ops.ones(shape, dtype=tensor.dtype)
+ return array_ops.concat([tensor, ones], axis=rank - 1)
+
+
+def scope_string_from_params(params):
+ """Builds a variable scope string name from the given parameters.
+
+ Supported parameters are:
+ * tensors
+ * booleans
+ * ints
+ * strings
+ * depth-1 tuples/lists of ints
+ * any depth tuples/lists of tensors
+ Other parameter types will throw an error.
+
+ Args:
+ params: A parameter or list of parameters.
+
+ Returns:
+ A string to use for the variable scope.
+
+ Raises:
+ ValueError: if params includes an unsupported type.
+ """
+ params = params if isinstance(params, (tuple, list)) else (params,)
+
+ name_parts = []
+ for param in params:
+ if param is None:
+ name_parts.append("None")
+ elif isinstance(param, (tuple, list)):
+ if all([isinstance(p, int) for p in param]):
+ name_parts.append("-".join([str(p) for p in param]))
+ else:
+ name_parts.append(scope_string_from_name(param))
+ elif isinstance(param, (str, int, bool)):
+ name_parts.append(str(param))
+ elif isinstance(param, (tf_ops.Tensor, variables.Variable)):
+ name_parts.append(scope_string_from_name(param))
+ elif isinstance(param, utils.PartitionedTensor):
+ name_parts.append(scope_string_from_name(param.tensors))
+ else:
+ raise ValueError("Encountered an unsupported param type {}".format(
+ type(param)))
+ return "_".join(name_parts)
+
+
+def scope_string_from_name(tensor):
+ if isinstance(tensor, (tuple, list)):
+ return "__".join([scope_string_from_name(t) for t in tensor])
+ # "gradients/add_4_grad/Reshape:0" -> "gradients_add_4_grad_Reshape"
+ return tensor.name.split(":")[0].replace("/", "_")
+
+
+def scalar_or_tensor_to_string(val):
+ return repr(val) if np.isscalar(val) else scope_string_from_name(val)
+
+
+def list_to_string(lst):
+ return "_".join(val if isinstance(val, six.string_types)
+ else scalar_or_tensor_to_string(val) for val in lst)
+
+
+def graph_func_to_id(func):
+ """Returns a hashable object that represents func's computation."""
+ # TODO(b/74201126): replace with Topohash of func's output
+ return func.func_id
+
+
+def graph_func_to_string(func):
+ # TODO(b/74201126): replace with Topohash of func's output
+ return list_to_string(func.func_id)
+
+
+def _subsample_for_cov_computation(array, name=None):
+ """Subsamples the first dimension of the array.
+
+  `array` (A) is a tensor of shape `[batch_size, dim_2]`. Then the covariance
+  matrix (A^T A) has shape `[dim_2, dim_2]`, i.e. `dim_2 ** 2` entries.
+  Subsample only if the number of outer
+ products per row of the covariance matrix is greater than
+ `_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW`.
+
+ Args:
+ array: Tensor, of shape `[batch_size, dim_2]`.
+    name: `string` or None. Name of op. (Default: None)
+
+ Returns:
+ A tensor of shape `[max_samples, dim_2]`.
+
+ Raises:
+    ValueError: If array is not matrix-shaped.
+ ValueError: If array's batch_size cannot be inferred.
+
+ """
+ with tf_ops.name_scope(name, "subsample", [array]):
+ array = tf_ops.convert_to_tensor(array)
+ if len(array.shape) != 2:
+ raise ValueError("Input param array must be a matrix.")
+
+ batch_size = array.shape.as_list()[0]
+ if batch_size is None:
+ raise ValueError("Unable to get batch_size from input param array.")
+
+ num_cov_rows = array.shape.as_list()[-1]
+ max_batch_size = int(_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW * num_cov_rows)
+ if batch_size <= max_batch_size:
+ return array
+
+ return _random_tensor_gather(array, max_batch_size)
+
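+# For example (illustrative): with _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1,
+# _subsample_for_cov_computation applied to an array of shape [512, 128] keeps
+# at most 1 * 128 = 128 randomly chosen rows, producing a [128, 128] array.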
+
+def _random_tensor_gather(array, max_size):
+ """Generates a random set of indices and gathers the value at the indices.
+
+ Args:
+ array: Tensor, of shape `[batch_size, dim_2]`.
+ max_size: int, Number of indices to sample.
+
+ Returns:
+ A tensor of shape `[max_size, ...]`.
+ """
+ batch_size = array.shape.as_list()[0]
+ indices = random_ops.random_shuffle(math_ops.range(0, batch_size))[:max_size]
+ return array_ops.gather(array, indices)
+
+
+@six.add_metaclass(abc.ABCMeta)
+class FisherFactor(object):
+ """Base class for objects modeling factors of approximate Fisher blocks.
+
+ A FisherFactor represents part of an approximate Fisher Information matrix.
+ For example, one approximation to the Fisher uses the Kronecker product of two
+ FisherFactors A and B, F = kron(A, B). FisherFactors are composed with
+ FisherBlocks to construct a block-diagonal approximation to the full Fisher.
+
+ FisherFactors are backed by a single, non-trainable variable that is updated
+ by running FisherFactor.make_covariance_update_op(). The shape and type of
+ this variable is implementation specific.
+
+ Note that for blocks that aren't based on approximations, a 'factor' can
+ be the entire block itself, as is the case for the diagonal and full
+ representations.
+ """
+
+ def __init__(self):
+ self._cov = None
+
+ @abc.abstractproperty
+ def _var_scope(self):
+ """Variable scope for this FisherFactor instance.
+
+ Returns:
+      A string that uniquely identifies this FisherFactor instance.
+ """
+ pass
+
+ @property
+ def name(self):
+ return self._var_scope
+
+ @abc.abstractproperty
+ def _cov_shape(self):
+ """The shape of the variable backing this FisherFactor."""
+ pass
+
+ @abc.abstractproperty
+ def _num_sources(self):
+ """The number of things to sum over when updating covariance variable.
+
+ The default make_covariance_update_op function will call _compute_new_cov
+ with indices ranging from 0 to _num_sources-1. The typical situation is
+ where the factor wants to sum the statistics it computes over multiple
+ backpropped "gradients" (typically passed in via "tensors" or
+ "outputs_grads" arguments).
+ """
+ pass
+
+ @abc.abstractproperty
+ def _num_towers(self):
+ pass
+
+ @abc.abstractproperty
+ def _dtype(self):
+ """dtype for variable backing this factor."""
+ pass
+
+ @property
+ def _cov_initializer(self):
+ """Function for initializing covariance variable."""
+ return covariance_initializer
+
+ def instantiate_cov_variables(self):
+ """Makes the internal cov variable(s)."""
+ assert self._cov is None
+ with variable_scope.variable_scope(self._var_scope):
+ self._cov = variable_scope.get_variable(
+ "cov",
+ initializer=self._cov_initializer,
+ shape=self._cov_shape,
+ trainable=False,
+ dtype=self._dtype)
+
+ @abc.abstractmethod
+ def _compute_new_cov(self, source, tower):
+ """Computes minibatch-estimated covariance for a single source.
+
+ Args:
+ source: int in [0, self._num_sources). Which source to use when computing
+ the cov update.
+ tower: int in [0, self._num_towers). Which tower to use when computing
+ the cov update.
+
+ Returns:
+ Tensor of same shape as self.get_cov().
+ """
+ pass
+
+ def make_covariance_update_op(self, ema_decay):
+ """Constructs and returns the covariance update Op.
+
+ Args:
+ ema_decay: The exponential moving average decay (float or Tensor).
+ Returns:
+ An Op for updating the covariance Variable referenced by _cov.
+ """
+ new_cov_contribs = []
+ for source in range(self._num_sources):
+ for tower in range(self._num_towers):
+ device = (self._get_data_device(tower)
+ if TOWER_STRATEGY == "separate" else None)
+ with place_on_device(device):
+ new_cov_contribs.append(self._compute_new_cov(source, tower))
+
+ new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers)
+
+ # Compute average of 'new_cov' across all TPU cores. On a TPU, each
+ # instance of 'new_cov' will be based on a different minibatch. This ensures
+ # that by the end of assign_moving_average(), all TPU cores see the same
+ # value for self._cov.
+ #
+ # Other implementations of make_covariance_update_op() that accumulate
+ # statistics in other variables should mimic this behavior.
+ if utils.on_tpu():
+ new_cov = utils.cross_replica_mean(new_cov)
+
+ return moving_averages.assign_moving_average(
+ self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
+
+ @abc.abstractmethod
+ def _get_data_device(self, tower):
+ pass
+
+ @abc.abstractmethod
+ def instantiate_inv_variables(self):
+ """Makes the internal "inverse" variable(s)."""
+ pass
+
+ @abc.abstractmethod
+ def make_inverse_update_ops(self):
+ """Create and return update ops corresponding to registered computations."""
+ pass
+
+ def get_cov(self):
+ return self._cov
+
+ @abc.abstractmethod
+ def get_cov_as_linear_operator(self):
+ pass
+
+ @abc.abstractmethod
+ def register_matpower(self, exp, damping_func):
+ pass
+
+ @abc.abstractmethod
+ def register_cholesky(self, damping_func):
+ pass
+
+ @abc.abstractmethod
+ def register_cholesky_inverse(self, damping_func):
+ pass
+
+ @abc.abstractmethod
+ def get_matpower(self, exp, damping_func):
+ pass
+
+ @abc.abstractmethod
+ def get_cholesky(self, damping_func):
+ pass
+
+ @abc.abstractmethod
+ def get_cholesky_inverse(self, damping_func):
+ pass
+
+
+class DenseSquareMatrixFactor(FisherFactor):
+ """Base class for FisherFactors that are stored as dense square matrices.
+
+ This class explicitly calculates and stores inverses of their `cov` matrices,
+ which must be square dense matrices.
+
+ Subclasses must implement the _compute_new_cov method, and the _var_scope and
+ _cov_shape properties.
+ """
+
+ # TODO(b/69108481): This class (and its subclasses) should be refactored to
+ # serve the matrix quantities it computes as both (potentially stale)
+ # variables, updated by the inverse update ops, and fresh values stored in
+  # tensors that are recomputed once every session.run() call. Currently
+  # matpower and damp_inverse have the former behavior, while
+  # eigendecomposition has the latter.
+
+ def __init__(self):
+ self._matpower_by_exp_and_damping = {} # { (float, hashable): variable }
+ self._matpower_registrations = set() # { (float, hashable) }
+ self._eigendecomp = None
+ self._damping_funcs_by_id = {} # {hashable: lambda}
+
+ self._cholesky_registrations = set() # { hashable }
+ self._cholesky_inverse_registrations = set() # { hashable }
+
+ self._cholesky_by_damping = {} # { hashable: variable }
+ self._cholesky_inverse_by_damping = {} # { hashable: variable }
+
+ super(DenseSquareMatrixFactor, self).__init__()
+
+ def get_cov_as_linear_operator(self):
+ assert self.get_cov().shape.ndims == 2
+ return lo.LinearOperatorFullMatrix(self.get_cov(),
+ is_self_adjoint=True,
+ is_square=True)
+
+ def _register_damping(self, damping_func):
+ damping_id = graph_func_to_id(damping_func)
+ if damping_id not in self._damping_funcs_by_id:
+ self._damping_funcs_by_id[damping_id] = damping_func
+ return damping_id
+
+ def register_inverse(self, damping_func):
+ # Just for backwards compatibility of some old code and tests
+ self.register_matpower(-1, damping_func)
+
+ def register_matpower(self, exp, damping_func):
+ """Registers a matrix power to be maintained and served on demand.
+
+ This creates a variable and signals make_inverse_update_ops to make the
+ corresponding update op. The variable can be read via the method
+ get_matpower.
+
+ Args:
+ exp: float. The exponent to use in the matrix power.
+ damping_func: A function that computes a 0-D Tensor or a float which will
+ be the damping value used. i.e. damping = damping_func().
+ """
+ if exp == 1.0:
+ return
+
+ damping_id = self._register_damping(damping_func)
+
+ if (exp, damping_id) not in self._matpower_registrations:
+ self._matpower_registrations.add((exp, damping_id))
+
+ def register_cholesky(self, damping_func):
+ """Registers a Cholesky factor to be maintained and served on demand.
+
+ This creates a variable and signals make_inverse_update_ops to make the
+ corresponding update op. The variable can be read via the method
+ get_cholesky.
+
+ Args:
+ damping_func: A function that computes a 0-D Tensor or a float which will
+ be the damping value used. i.e. damping = damping_func().
+ """
+ damping_id = self._register_damping(damping_func)
+
+ if damping_id not in self._cholesky_registrations:
+ self._cholesky_registrations.add(damping_id)
+
+ def register_cholesky_inverse(self, damping_func):
+ """Registers an inverse Cholesky factor to be maintained/served on demand.
+
+ This creates a variable and signals make_inverse_update_ops to make the
+ corresponding update op. The variable can be read via the method
+ get_cholesky_inverse.
+
+ Args:
+ damping_func: A function that computes a 0-D Tensor or a float which will
+ be the damping value used. i.e. damping = damping_func().
+ """
+ damping_id = self._register_damping(damping_func)
+
+ if damping_id not in self._cholesky_inverse_registrations:
+ self._cholesky_inverse_registrations.add(damping_id)
+
+ def instantiate_inv_variables(self):
+ """Makes the internal "inverse" variable(s)."""
+
+ for (exp, damping_id) in self._matpower_registrations:
+ exp_string = scalar_or_tensor_to_string(exp)
+ damping_func = self._damping_funcs_by_id[damping_id]
+ damping_string = graph_func_to_string(damping_func)
+ with variable_scope.variable_scope(self._var_scope):
+ matpower = variable_scope.get_variable(
+ "matpower_exp{}_damp{}".format(exp_string, damping_string),
+ initializer=inverse_initializer,
+ shape=self._cov_shape,
+ trainable=False,
+ dtype=self._dtype)
+ assert (exp, damping_id) not in self._matpower_by_exp_and_damping
+ self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower
+
+ for damping_id in self._cholesky_registrations:
+ damping_func = self._damping_funcs_by_id[damping_id]
+ damping_string = graph_func_to_string(damping_func)
+ with variable_scope.variable_scope(self._var_scope):
+ chol = variable_scope.get_variable(
+ "cholesky_damp{}".format(damping_string),
+ initializer=inverse_initializer,
+ shape=self._cov_shape,
+ trainable=False,
+ dtype=self._dtype)
+ assert damping_id not in self._cholesky_by_damping
+ self._cholesky_by_damping[damping_id] = chol
+
+ for damping_id in self._cholesky_inverse_registrations:
+ damping_func = self._damping_funcs_by_id[damping_id]
+ damping_string = graph_func_to_string(damping_func)
+ with variable_scope.variable_scope(self._var_scope):
+ cholinv = variable_scope.get_variable(
+ "cholesky_inverse_damp{}".format(damping_string),
+ initializer=inverse_initializer,
+ shape=self._cov_shape,
+ trainable=False,
+ dtype=self._dtype)
+ assert damping_id not in self._cholesky_inverse_by_damping
+ self._cholesky_inverse_by_damping[damping_id] = cholinv
+
+ def make_inverse_update_ops(self):
+ """Create and return update ops corresponding to registered computations."""
+ ops = []
+
+ num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping
+ if exp == -1)
+
+ num_other_matpower = len(self._matpower_by_exp_and_damping) - num_inverses
+
+ other_matrix_power_registered = num_other_matpower >= 1
+
+ use_eig = (
+ self._eigendecomp or other_matrix_power_registered or
+ num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD)
+
+ # We precompute these so we don't need to evaluate them multiple times (for
+ # each matrix power that uses them)
+ damping_value_by_id = {damping_id: math_ops.cast(
+ self._damping_funcs_by_id[damping_id](), self._dtype)
+ for damping_id in self._damping_funcs_by_id}
+
+ if use_eig:
+ eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence
+
+ for (exp, damping_id), matpower in (
+ self._matpower_by_exp_and_damping.items()):
+ damping = damping_value_by_id[damping_id]
+ ops.append(
+ matpower.assign(
+ math_ops.matmul(eigenvectors *
+ (eigenvalues + damping)**exp,
+ array_ops.transpose(eigenvectors))))
+ # These ops share computation and should be run on a single device.
+ ops = [control_flow_ops.group(*ops)]
+ else:
+ for (exp, damping_id), matpower in (
+ self._matpower_by_exp_and_damping.items()):
+ assert exp == -1
+ damping = damping_value_by_id[damping_id]
+ ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping)))
+
+ # TODO(b/77902055): If inverses are being computed with Cholesky's
+ # we can share the work. Instead this code currently just computes the
+ # Cholesky a second time. It does at least share work between requests for
+ # Cholesky's and Cholesky inverses with the same damping id.
+ for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items():
+ cholesky_ops = []
+
+ damping = damping_value_by_id[damping_id]
+ cholesky_value = utils.cholesky(self.get_cov(), damping)
+
+ if damping_id in self._cholesky_by_damping:
+ cholesky = self._cholesky_by_damping[damping_id]
+ cholesky_ops.append(cholesky.assign(cholesky_value))
+
+ identity = linalg_ops.eye(cholesky_value.shape.as_list()[0],
+ dtype=cholesky_value.dtype)
+ cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value,
+ identity)
+ cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value))
+
+ ops.append(control_flow_ops.group(*cholesky_ops))
+
+ for damping_id, cholesky in self._cholesky_by_damping.items():
+ if damping_id not in self._cholesky_inverse_by_damping:
+ damping = damping_value_by_id[damping_id]
+ cholesky_value = utils.cholesky(self.get_cov(), damping)
+ ops.append(cholesky.assign(cholesky_value))
+
+ self._eigendecomp = False
+ return ops
+
+ def get_inverse(self, damping_func):
+ # Just for backwards compatibility of some old code and tests
+ return self.get_matpower(-1, damping_func)
+
+ def get_matpower(self, exp, damping_func):
+ # Note that this function returns a variable which gets updated by the
+ # inverse ops. It may be stale / inconsistent with the latest value of
+ # get_cov().
+ if exp != 1:
+ damping_id = graph_func_to_id(damping_func)
+ matpower = self._matpower_by_exp_and_damping[(exp, damping_id)]
+ else:
+ matpower = self.get_cov()
+ identity = linalg_ops.eye(matpower.shape.as_list()[0],
+ dtype=matpower.dtype)
+ matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity
+
+ assert matpower.shape.ndims == 2
+ return lo.LinearOperatorFullMatrix(matpower,
+ is_non_singular=True,
+ is_self_adjoint=True,
+ is_positive_definite=True,
+ is_square=True)
+
+ def get_cholesky(self, damping_func):
+ # Note that this function returns a variable which gets updated by the
+ # inverse ops. It may be stale / inconsistent with the latest value of
+ # get_cov().
+ damping_id = graph_func_to_id(damping_func)
+ cholesky = self._cholesky_by_damping[damping_id]
+ assert cholesky.shape.ndims == 2
+ return lo.LinearOperatorFullMatrix(cholesky,
+ is_non_singular=True,
+ is_square=True)
+
+ def get_cholesky_inverse(self, damping_func):
+ # Note that this function returns a variable which gets updated by the
+ # inverse ops. It may be stale / inconsistent with the latest value of
+ # get_cov().
+ damping_id = graph_func_to_id(damping_func)
+ cholesky_inv = self._cholesky_inverse_by_damping[damping_id]
+ assert cholesky_inv.shape.ndims == 2
+ return lo.LinearOperatorFullMatrix(cholesky_inv,
+ is_non_singular=True,
+ is_square=True)
+
+ def get_eigendecomp(self):
+ """Creates or retrieves eigendecomposition of self._cov."""
+ # Unlike get_matpower this doesn't retrieve a stored variable, but instead
+ # always computes a fresh version from the current value of get_cov().
+ if not self._eigendecomp:
+ eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov())
+
+      # The matrix self._cov is positive semidefinite by construction, but the
+      # numerical eigenvalues could be negative due to numerical errors, so
+      # here we clip them to be at least EIGENVALUE_CLIPPING_THRESHOLD.
+ clipped_eigenvalues = math_ops.maximum(eigenvalues,
+ EIGENVALUE_CLIPPING_THRESHOLD)
+ self._eigendecomp = (clipped_eigenvalues, eigenvectors)
+
+ return self._eigendecomp
+
+
+class FullFactor(DenseSquareMatrixFactor):
+ """FisherFactor for a full matrix representation of the Fisher of a parameter.
+
+ Note that this uses the naive "square the sum estimator", and so is applicable
+ to any type of parameter in principle, but has very high variance.
+ """
+
+ def __init__(self,
+ params_grads,
+ batch_size):
+ self._batch_size = batch_size
+ self._params_grads = tuple(utils.ensure_sequence(params_grad)
+ for params_grad in params_grads)
+ super(FullFactor, self).__init__()
+
+ @property
+ def _var_scope(self):
+ return "ff_full_" + scope_string_from_params(
+ [self._params_grads, self._batch_size])
+
+ @property
+ def _cov_shape(self):
+ size = sum(param_grad.shape.num_elements()
+ for param_grad in self._params_grads[0])
+ return (size, size)
+
+ @property
+ def _num_sources(self):
+ return len(self._params_grads)
+
+ @property
+ def _num_towers(self):
+ return 1
+
+ @property
+ def _dtype(self):
+ return self._params_grads[0][0].dtype
+
+ def _compute_new_cov(self, source, tower):
+ assert tower == 0
+
+ # This will be a very basic rank 1 estimate
+ params_grads_flat = utils.tensors_to_column(self._params_grads[source])
+ return ((params_grads_flat * array_ops.transpose(
+ params_grads_flat)) / math_ops.cast(self._batch_size,
+ params_grads_flat.dtype))
+
+ def _get_data_device(self, tower):
+ return None
+
+
+class DiagonalFactor(FisherFactor):
+ """A base class for FisherFactors that use diagonal approximations.
+
+ A DiagonalFactor's covariance variable can be of any shape, but must contain
+ exactly one entry per parameter.
+ """
+
+ def __init__(self):
+ super(DiagonalFactor, self).__init__()
+
+ def get_cov_as_linear_operator(self):
+ assert self._matrix_diagonal.shape.ndims == 1
+ return lo.LinearOperatorDiag(self._matrix_diagonal,
+ is_self_adjoint=True,
+ is_square=True)
+
+ @property
+ def _cov_initializer(self):
+ return diagonal_covariance_initializer
+
+ @property
+ def _matrix_diagonal(self):
+ return array_ops.reshape(self.get_cov(), [-1])
+
+ def make_inverse_update_ops(self):
+ return []
+
+ def instantiate_inv_variables(self):
+ pass
+
+ def register_matpower(self, exp, damping_func):
+ pass
+
+ def register_cholesky(self, damping_func):
+ pass
+
+ def register_cholesky_inverse(self, damping_func):
+ pass
+
+ def get_matpower(self, exp, damping_func):
+ matpower_diagonal = (self._matrix_diagonal
+ + math_ops.cast(damping_func(), self._dtype))**exp
+ return lo.LinearOperatorDiag(matpower_diagonal,
+ is_non_singular=True,
+ is_self_adjoint=True,
+ is_positive_definite=True,
+ is_square=True)
+
+ def get_cholesky(self, damping_func):
+ return self.get_matpower(0.5, damping_func)
+
+ def get_cholesky_inverse(self, damping_func):
+ return self.get_matpower(-0.5, damping_func)
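+    # A rough sketch (assuming NumPy) of why the diagonal case reduces to
+    # elementwise powers: for D = diag(d), chol(D + damping * I) is just
+    # diag((d + damping)**0.5), so get_cholesky() / get_cholesky_inverse()
+    # are the 0.5 / -0.5 matrix powers returned by get_matpower() above.
+    #
+    #   import numpy as np
+    #   d, damping = np.array([1., 4.]), 0.5
+    #   assert np.allclose(np.linalg.cholesky(np.diag(d + damping)),
+    #                      np.diag((d + damping) ** 0.5))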
+
+
+class NaiveDiagonalFactor(DiagonalFactor):
+ """FisherFactor for a diagonal approximation of any type of param's Fisher.
+
+ Note that this uses the naive "square the sum estimator", and so is applicable
+ to any type of parameter in principle, but has very high variance.
+ """
+
+ def __init__(self,
+ params_grads,
+ batch_size):
+ """Initializes NaiveDiagonalFactor instance.
+
+ Args:
+      params_grads: Sequence of Tensors, each with the same shape as the
+        parameters this FisherFactor corresponds to. For example, the gradient
+        of the loss with respect to the parameters.
+      batch_size: int or 0-D Tensor. Size of the minibatch.
+ """
+ self._params_grads = tuple(utils.ensure_sequence(params_grad)
+ for params_grad in params_grads)
+ self._batch_size = batch_size
+ super(NaiveDiagonalFactor, self).__init__()
+
+ @property
+ def _var_scope(self):
+ return "ff_naivediag_" + scope_string_from_params(
+ [self._params_grads, self._batch_size])
+
+ @property
+ def _cov_shape(self):
+ size = sum(param_grad.shape.num_elements()
+ for param_grad in self._params_grads[0])
+ return [size, 1]
+
+ @property
+ def _num_sources(self):
+ return len(self._params_grads)
+
+ @property
+ def _num_towers(self):
+ return 1
+
+ @property
+ def _dtype(self):
+ return self._params_grads[0][0].dtype
+
+ def _compute_new_cov(self, source, tower):
+ assert tower == 0
+
+ params_grads_flat = utils.tensors_to_column(self._params_grads[source])
+ return (math_ops.square(params_grads_flat) / math_ops.cast(
+ self._batch_size, params_grads_flat.dtype))
+
+ def _get_data_device(self, tower):
+ return None
+
+
+class EmbeddingInputKroneckerFactor(DiagonalFactor):
+ r"""FisherFactor for input to an embedding layer.
+
+  Given input_ids = [batch_size, input_size] representing indices into a
+  [vocab_size, embedding_size] embedding matrix, the input covariance is
+  approximated by a diagonal matrix,
+
+ Cov(input_ids, input_ids) =
+ (1/batch_size) sum_{i} diag(n_hot(input[i]) ** 2).
+
+ where n_hot() constructs an n-hot binary vector and diag() constructs a
+ diagonal matrix of size [vocab_size, vocab_size].
+ """
+
+ def __init__(self, input_ids, vocab_size, dtype=None):
+ """Instantiate EmbeddingInputKroneckerFactor.
+
+ Args:
+ input_ids: List of Tensors of shape [batch_size, input_size] and dtype
+ int32. Indices into embedding matrix. List index is tower.
+ vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'.
+ dtype: dtype for covariance statistics. Must be a floating point type.
+ Defaults to float32.
+ """
+ self._input_ids = input_ids
+ self._vocab_size = vocab_size
+ self._cov_dtype = dtype or dtypes.float32
+
+ super(EmbeddingInputKroneckerFactor, self).__init__()
+
+ @property
+ def _var_scope(self):
+ return "ff_diag_embedding_" + scope_string_from_params(self._input_ids)
+
+ @property
+ def _cov_shape(self):
+ return [self._vocab_size]
+
+ @property
+ def _num_sources(self):
+ return 1
+
+ @property
+ def _num_towers(self):
+ return len(self._input_ids)
+
+ @property
+ def _dtype(self):
+ return self._cov_dtype
+
+ def _compute_new_cov(self, source, tower):
+ assert source == 0
+
+ input_ids = self._input_ids[tower]
+
+ if len(input_ids.shape) > 2:
+ raise ValueError(
+ "Input to embeddings must have rank <= 2. Found rank %d." % len(
+ input_ids.shape))
+
+ batch_size = array_ops.shape(input_ids)[0]
+
+ # Transform indices into one-hot vectors.
+ #
+ # TODO(b/72714822): There must be a faster way to construct the diagonal
+    # covariance matrix! This operation is O(batch_size * vocab_size), whereas
+    # it should be O(batch_size * input_size).
+ flat_input_ids = array_ops.reshape(input_ids, [-1])
+ one_hots = array_ops.one_hot(flat_input_ids,
+ self._vocab_size) # [?, vocab_size]
+
+ # Take average across examples. Note that, because all entries have
+ # magnitude zero or one, there's no need to square the entries.
+ #
+ # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation
+ # within an example such as average.
+ #
+ # TODO(b/72714822): Support for partitioned embeddings.
+ new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size]
+ new_cov /= math_ops.cast(batch_size, new_cov.dtype)
+
+ return new_cov
+
+ def _get_data_device(self, tower):
+ return self._input_ids[tower].device
+
+
+class FullyConnectedDiagonalFactor(DiagonalFactor):
+ r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher.
+
+ Given in = [batch_size, input_size] and out_grad = [batch_size, output_size],
+ approximates the covariance as,
+
+ Cov(in, out) = (1/batch_size) sum_{i} outer(in[i], out_grad[i]) ** 2.0
+
+ where the square is taken element-wise.
+ """
+
+ def __init__(self,
+ inputs,
+ outputs_grads,
+ has_bias=False):
+ """Instantiate FullyConnectedDiagonalFactor.
+
+ Args:
+ inputs: List of Tensors of shape [batch_size, input_size]. Inputs to this
+ layer. List index is towers.
+ outputs_grads: List of Tensors, each of shape [batch_size, output_size],
+ which are the gradients of the loss with respect to the layer's
+ outputs. First index is source, second is tower.
+ has_bias: bool. If True, append '1' to each input.
+ """
+ self._inputs = inputs
+ self._has_bias = has_bias
+ self._outputs_grads = outputs_grads
+ self._squared_inputs = None
+
+ super(FullyConnectedDiagonalFactor, self).__init__()
+
+ @property
+ def _var_scope(self):
+ return "ff_diagfc_" + scope_string_from_params(
+ tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads)))
+
+ @property
+ def _cov_shape(self):
+ input_size = self._inputs[0].shape[1] + self._has_bias
+ output_size = self._outputs_grads[0][0].shape[1]
+ return [input_size, output_size]
+
+ @property
+ def _num_sources(self):
+ return len(self._outputs_grads)
+
+ @property
+ def _num_towers(self):
+ return len(self._inputs)
+
+ @property
+ def _dtype(self):
+ return self._outputs_grads[0][0].dtype
+
+ def make_covariance_update_op(self, ema_decay):
+
+ self._squared_inputs = []
+ for tower in range(self._num_towers):
+ inputs = self._inputs[tower]
+
+ with place_on_device(self._get_data_device(tower)):
+ if self._has_bias:
+ inputs = append_homog(inputs)
+ self._squared_inputs.append(math_ops.square(inputs))
+
+ return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op(
+ ema_decay)
+
+ def _compute_new_cov(self, source, tower):
+ batch_size = array_ops.shape(self._squared_inputs[tower])[0]
+ outputs_grad = self._outputs_grads[source][tower]
+
+ # The well-known special formula that uses the fact that the entry-wise
+ # square of an outer product is the outer-product of the entry-wise squares.
+ # The gradient is the outer product of the input and the output gradients,
+ # so we just square both and then take their outer-product.
+ new_cov = math_ops.matmul(
+ self._squared_inputs[tower],
+ math_ops.square(outputs_grad),
+ transpose_a=True)
+ new_cov /= math_ops.cast(batch_size, new_cov.dtype)
+ return new_cov
+
+ def _get_data_device(self, tower):
+ return self._inputs[tower].device
+
+
+class ConvDiagonalFactor(DiagonalFactor):
+ """FisherFactor for a diagonal approx of a convolutional layer's Fisher."""
+
+ def __init__(self,
+ inputs,
+ outputs_grads,
+ filter_shape,
+ strides,
+ padding,
+ data_format=None,
+ dilations=None,
+ has_bias=False):
+ """Creates a ConvDiagonalFactor object.
+
+ Args:
+ inputs: List of Tensors of shape [batch_size, height, width, in_channels].
+ Input activations to this layer. List index is towers.
+ outputs_grads: List of Tensors, each of shape [batch_size,
+ height, width, out_channels], which are the gradients of the loss
+ with respect to the layer's outputs. First index is source, second
+ index is tower.
+ filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels,
+ out_channels). Represents shape of kernel used in this layer.
+ strides: The stride size in this layer (1-D Tensor of length 4).
+      padding: str. Padding method for this layer ("SAME" or "VALID").
+ data_format: None or str. Format of conv2d inputs.
+ dilations: None or tuple of 4 ints.
+ has_bias: Python bool. If True, the layer is assumed to have a bias
+ parameter in addition to its filter parameter.
+
+ Raises:
+      ValueError: If inputs, outputs_grads, and filter_shape do not agree on
+        in_channels or out_channels.
+ ValueError: If strides, dilations are not length-4 lists of ints.
+ ValueError: If data_format does not put channel last.
+ """
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("Channel must be last.")
+ if any(input_.shape.ndims != 4 for input_ in inputs):
+ raise ValueError("inputs must be a list of 4-D Tensors.")
+ if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs):
+ raise ValueError("inputs and filter_shape must agree on in_channels.")
+ for i, outputs_grad in enumerate(outputs_grads):
+ if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad):
+ raise ValueError("outputs[%d] must be 4-D Tensor." % i)
+ if any(output_grad.shape.as_list()[-1] != filter_shape[-1]
+ for output_grad in outputs_grad):
+ raise ValueError(
+ "outputs[%d] and filter_shape must agree on out_channels." % i)
+ if len(strides) != 4:
+ raise ValueError("strides must be length-4 list of ints.")
+ if dilations is not None and len(dilations) != 4:
+ raise ValueError("dilations must be length-4 list of ints.")
+
+ self._inputs = inputs
+ self._outputs_grads = outputs_grads
+ self._filter_shape = filter_shape
+ self._strides = strides
+ self._padding = padding
+ self._data_format = data_format
+ self._dilations = dilations
+ self._has_bias = has_bias
+ self._patches = None
+
+ super(ConvDiagonalFactor, self).__init__()
+
+ @property
+ def _var_scope(self):
+ return "ff_convdiag_" + scope_string_from_params(
+ tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads)))
+
+ @property
+ def _cov_shape(self):
+ filter_height, filter_width, in_channels, out_channels = self._filter_shape
+ return [
+ filter_height * filter_width * in_channels + self._has_bias,
+ out_channels
+ ]
+
+ @property
+ def _num_sources(self):
+ return len(self._outputs_grads)
+
+ @property
+ def _num_towers(self):
+ return len(self._inputs)
+
+ @property
+ def _dtype(self):
+ return self._inputs[0].dtype
+
+ def make_covariance_update_op(self, ema_decay):
+ filter_height, filter_width, _, _ = self._filter_shape
+
+ # TODO(b/64144716): there is potential here for a big savings in terms
+ # of memory use.
+ if self._dilations is None:
+ rates = (1, 1, 1, 1)
+ else:
+ rates = tuple(self._dilations)
+
+ self._patches = []
+ for tower in range(self._num_towers):
+ with place_on_device(self._get_data_device(tower)):
+ patches = array_ops.extract_image_patches(
+ self._inputs[tower],
+ ksizes=[1, filter_height, filter_width, 1],
+ strides=self._strides,
+ rates=rates,
+ padding=self._padding)
+
+ if self._has_bias:
+ patches = append_homog(patches)
+
+ self._patches.append(patches)
+
+ return super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay)
+
+ def _compute_new_cov(self, source, tower):
+ patches = self._patches[tower]
+ batch_size = array_ops.shape(patches)[0]
+ outputs_grad = self._outputs_grads[source][tower]
+
+ new_cov = self._convdiag_sum_of_squares(patches, outputs_grad)
+ new_cov /= math_ops.cast(batch_size, new_cov.dtype)
+
+ return new_cov
+
+ def _convdiag_sum_of_squares(self, patches, outputs_grad):
+ # This computes the sum of the squares of the per-training-case "gradients".
+ # It does this simply by computing a giant tensor containing all of these,
+    # doing an entry-wise square, and then summing along the batch dimension.
+ case_wise_gradients = special_math_ops.einsum("bijk,bijl->bkl", patches,
+ outputs_grad)
+ return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0)
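+    # A small sketch (assuming NumPy) of what the einsum above computes per
+    # training case: a sum over spatial locations of outer products between
+    # patch vectors and output-gradient vectors.
+    #
+    #   import numpy as np
+    #   p = np.random.rand(2, 3, 3, 5)   # [batch, height, width, patch_size]
+    #   g = np.random.rand(2, 3, 3, 4)   # [batch, height, width, out_channels]
+    #   per_case = np.einsum("bijk,bijl->bkl", p, g)
+    #   direct = np.stack([sum(np.outer(p[b, i, j], g[b, i, j])
+    #                          for i in range(3) for j in range(3))
+    #                      for b in range(2)])
+    #   assert np.allclose(per_case, direct)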
+
+ def _get_data_device(self, tower):
+ return self._inputs[tower].device
+
+
+class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor):
+ """Kronecker factor for the input or output side of a fully-connected layer.
+ """
+
+ def __init__(self,
+ tensors,
+ has_bias=False):
+ """Instantiate FullyConnectedKroneckerFactor.
+
+ Args:
+ tensors: List of list of Tensors, each of shape [batch_size, n]. The
+ Tensors are typically either a layer's inputs or its output's gradients.
+ The first list index is source, the second is tower.
+ has_bias: bool. If True, append '1' to each row.
+ """
+ # The tensor argument is either a tensor of input activations or a tensor of
+ # output pre-activation gradients.
+ self._has_bias = has_bias
+ self._tensors = tensors
+ super(FullyConnectedKroneckerFactor, self).__init__()
+
+ @property
+ def _var_scope(self):
+ return "ff_fckron_" + scope_string_from_params(
+ tuple(nest.flatten(self._tensors)) + (self._has_bias,))
+
+ @property
+ def _cov_shape(self):
+ size = self._tensors[0][0].shape[1] + self._has_bias
+ return [size, size]
+
+ @property
+ def _num_sources(self):
+ return len(self._tensors)
+
+ @property
+ def _num_towers(self):
+ return len(self._tensors[0])
+
+ @property
+ def _dtype(self):
+ return self._tensors[0][0].dtype
+
+ def _compute_new_cov(self, source, tower):
+ tensor = self._tensors[source][tower]
+ if self._has_bias:
+ tensor = append_homog(tensor)
+ return compute_cov(tensor)
+
+ def _get_data_device(self, tower):
+ return self._tensors[0][tower].device
+
+
+class ConvInputKroneckerFactor(DenseSquareMatrixFactor):
+ r"""Kronecker factor for the input side of a convolutional layer.
+
+ Estimates E[ a a^T ] where a is the inputs to a convolutional layer given
+ example x. Expectation is taken over all examples and locations.
+
+  Equivalent to Omega in https://arxiv.org/abs/1602.01407. See Section 3.1,
+  "Estimating the factors", for details.
+ """
+
+ def __init__(self,
+ inputs,
+ filter_shape,
+ padding,
+ strides=None,
+ dilation_rate=None,
+ data_format=None,
+ extract_patches_fn=None,
+ has_bias=False,
+ sub_sample_inputs=None,
+ sub_sample_patches=None):
+ """Initializes ConvInputKroneckerFactor.
+
+ Args:
+ inputs: List of Tensors of shape [batch_size, ..spatial_input_size..,
+ in_channels]. Inputs to layer. List index is tower.
+ filter_shape: List of ints. Contains [..spatial_filter_size..,
+ in_channels, out_channels]. Shape of convolution kernel.
+ padding: str. Padding method for layer. "SAME" or "VALID".
+ strides: List of ints or None. Contains [..spatial_filter_strides..] if
+ 'extract_patches_fn' is compatible with tf.nn.convolution(), else
+ [1, ..spatial_filter_strides, 1].
+ dilation_rate: List of ints or None. Rate for dilation along each spatial
+ dimension if 'extract_patches_fn' is compatible with
+ tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
+ data_format: str or None. Format of input data.
+ extract_patches_fn: str or None. Name of function that extracts image
+ patches. One of "extract_convolution_patches", "extract_image_patches",
+ "extract_pointwise_conv2d_patches".
+ has_bias: bool. If True, append 1 to in_channel.
+ sub_sample_inputs: `bool`. If True, then subsample the inputs from which
+ the image patches are extracted. (Default: None)
+      sub_sample_patches: `bool`. If `True`, then subsample the extracted
+        patches. (Default: None)
+ """
+ self._inputs = inputs
+ self._filter_shape = filter_shape
+ self._strides = strides
+ self._padding = padding
+ self._dilation_rate = dilation_rate
+ self._data_format = data_format
+ self._extract_patches_fn = extract_patches_fn
+ self._has_bias = has_bias
+ if sub_sample_inputs is None:
+ self._sub_sample_inputs = _SUB_SAMPLE_INPUTS
+ else:
+ self._sub_sample_inputs = sub_sample_inputs
+
+ if sub_sample_patches is None:
+ self._sub_sample_patches = _SUB_SAMPLE_OUTER_PRODUCTS
+ else:
+ self._sub_sample_patches = sub_sample_patches
+ super(ConvInputKroneckerFactor, self).__init__()
+
+ @property
+ def _var_scope(self):
+ return "ff_convinkron_" + scope_string_from_params(
+ tuple(self._inputs) +
+ tuple((self._filter_shape, self._strides, self._padding,
+ self._dilation_rate, self._data_format, self._has_bias)))
+
+ @property
+ def _cov_shape(self):
+ spatial_filter_shape = self._filter_shape[0:-2]
+ in_channels = self._filter_shape[-2]
+ size = np.prod(spatial_filter_shape) * in_channels + self._has_bias
+ return [size, size]
+
+ @property
+ def _num_sources(self):
+ return 1
+
+ @property
+ def _num_towers(self):
+ return len(self._inputs)
+
+ @property
+ def _dtype(self):
+ return self._inputs[0].dtype
+
+ def _compute_new_cov(self, source, tower):
+ assert source == 0
+
+ inputs = self._inputs[tower]
+ if self._sub_sample_inputs:
+ batch_size = inputs.shape.as_list()[0]
+ max_size = int(batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR)
+ inputs = _random_tensor_gather(inputs, max_size)
+
+ # TODO(b/64144716): there is potential here for a big savings in terms of
+ # memory use.
+ if self._extract_patches_fn in [None, "extract_convolution_patches"]:
+ patches = utils.extract_convolution_patches(
+ inputs,
+ self._filter_shape,
+ padding=self._padding,
+ strides=self._strides,
+ dilation_rate=self._dilation_rate,
+ data_format=self._data_format)
+
+ elif self._extract_patches_fn == "extract_image_patches":
+ assert inputs.shape.ndims == 4
+ assert len(self._filter_shape) == 4
+ assert len(self._strides) == 4, self._strides
+ if self._dilation_rate is None:
+ rates = [1, 1, 1, 1]
+ else:
+ rates = self._dilation_rate
+ assert len(rates) == 4
+ assert rates[0] == rates[-1] == 1
+ patches = array_ops.extract_image_patches(
+ inputs,
+ ksizes=[1] + list(self._filter_shape[0:-2]) + [1],
+ strides=self._strides,
+ rates=rates,
+ padding=self._padding)
+
+ elif self._extract_patches_fn == "extract_pointwise_conv2d_patches":
+ assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)]
+ assert self._filter_shape[0] == self._filter_shape[1] == 1
+ patches = utils.extract_pointwise_conv2d_patches(
+ inputs, self._filter_shape, data_format=None)
+
+ else:
+ raise NotImplementedError(self._extract_patches_fn)
+
+ flatten_size = np.prod(self._filter_shape[0:-1])
+ # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde
+ # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14),
+ # where M = minibatch size, |T| = number of spatial locations,
+ # |Delta| = number of spatial offsets, and J = number of input maps
+ # for convolutional layer l.
+ patches_flat = array_ops.reshape(patches, [-1, flatten_size])
+
+    if self._sub_sample_patches:
+      patches_flat = _subsample_for_cov_computation(patches_flat)
+
+    # We append a homogeneous coordinate to patches_flat if the layer has
+    # bias parameters. This gives us [[A_l]]_H from the paper.
+    if self._has_bias:
+      patches_flat = append_homog(patches_flat)
+ # We call compute_cov without passing in a normalizer. compute_cov uses
+ # the first dimension of patches_flat i.e. M|T| as the normalizer by
+ # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with
+ # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from
+ # the paper but has a different scale here for consistency with
+ # ConvOutputKroneckerFactor.
+ # (Tilde omitted over A for clarity.)
+ return compute_cov(patches_flat)
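+    # As a rough illustration of the shapes above (hypothetical numbers): for
+    # a 3x3 kernel over J = 8 input maps applied to a 16x16 feature map with
+    # "SAME" padding and unit strides, |Delta| = 9 and |T| = 256, so
+    # patches_flat has shape [M * 256, 72] and the returned covariance is
+    # 72 x 72 (73 x 73 if a homogeneous coordinate is appended for the bias).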
+
+ def _get_data_device(self, tower):
+ return self._inputs[tower].device
+
+
+class ConvOutputKroneckerFactor(DenseSquareMatrixFactor):
+ r"""Kronecker factor for the output side of a convolutional layer.
+
+ Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer
+ given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over
+ all examples and locations.
+
+  Equivalent to Gamma in https://arxiv.org/abs/1602.01407. See Section 3.1,
+  "Estimating the factors", for details.
+ """
+
+ def __init__(self, outputs_grads, data_format=None):
+ """Initializes ConvOutputKroneckerFactor.
+
+ Args:
+ outputs_grads: List of list of Tensors. Each Tensor is of shape
+ [batch_size, ..spatial_input_size.., out_channels]. First list index
+ is source, the second is tower.
+ data_format: None or str. Format of outputs_grads.
+
+ Raises:
+ ValueError: If channels are not final dimension.
+ """
+ if not utils.is_data_format_channel_last(data_format):
+ raise ValueError("Channel must be last.")
+ self._out_channels = outputs_grads[0][0].shape.as_list()[-1]
+ self._outputs_grads = outputs_grads
+ super(ConvOutputKroneckerFactor, self).__init__()
+
+ @property
+ def _var_scope(self):
+ return "ff_convoutkron_" + scope_string_from_params(
+ nest.flatten(self._outputs_grads))
+
+ @property
+ def _cov_shape(self):
+ size = self._out_channels
+ return [size, size]
+
+ @property
+ def _num_sources(self):
+ return len(self._outputs_grads)
+
+ @property
+ def _num_towers(self):
+ return len(self._outputs_grads[0])
+
+ @property
+ def _dtype(self):
+ return self._outputs_grads[0][0].dtype
+
+ def _compute_new_cov(self, source, tower):
+ outputs_grad = self._outputs_grads[source][tower]
+
+ # reshaped_tensor below is the matrix DS_l defined in the KFC paper
+ # (tilde omitted over S for clarity). It has shape M|T| x I, where
+ # M = minibatch size, |T| = number of spatial locations, and
+ # I = number of output maps for convolutional layer l.
+ reshaped_tensor = array_ops.reshape(outputs_grad, [-1, self._out_channels])
+ # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov,
+ # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
+ # as defined in the paper, with shape I x I.
+ # (Tilde omitted over S for clarity.)
+ return compute_cov(reshaped_tensor)
+
+ def _get_data_device(self, tower):
+ return self._outputs_grads[0][tower].device
+
+
+class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
+ """Kronecker factor for a fully connected layer used multiple times."""
+
+ def __init__(self,
+ tensors,
+ num_uses=None,
+ has_bias=False):
+ """Constructs a new `FullyConnectedMultiKF`.
+
+ Args:
+      tensors: List of lists of Tensors, each of shape
+        [num_uses * batch_size, n], which is a reshaped version of a Tensor
+        of shape [num_uses, batch_size, n]. Each of these tensors is usually
+        a layer's inputs or its output's gradients. The first list index is
+        source, the second is tower.
+ num_uses: int. The number of time-steps / uses.
+ has_bias: bool. If True, '1' is appended to each row.
+ """
+
+ self._num_uses = num_uses
+
+ self._cov_dt1 = None
+ self._make_cov_dt1 = False
+ self._option1quants_by_damping = {}
+ self._option2quants_by_damping = {}
+ self._option1quants_registrations = set()
+ self._option2quants_registrations = set()
+
+ super(FullyConnectedMultiKF, self).__init__(tensors=tensors,
+ has_bias=has_bias)
+
+ @property
+ def _num_timesteps(self):
+ return self._num_uses
+
+ @property
+ def _var_scope(self):
+ return "ff_fc_multi_" + scope_string_from_params(
+ tuple(nest.flatten(self._tensors))
+ + (self._num_timesteps, self._has_bias,))
+
+ def make_covariance_update_op(self, ema_decay):
+
+ op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay)
+
+ if self._cov_dt1 is not None:
+ new_cov_dt1_contribs = []
+ for source in range(self._num_sources):
+ for tower in range(self._num_towers):
+ with place_on_device(self._get_data_device(tower)):
+ new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source,
+ tower))
+
+ new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs)
+ / float(self._num_towers))
+
+ # See comments in FisherFactor.make_covariance_update_op() for details.
+ if utils.on_tpu():
+ new_cov_dt1 = utils.cross_replica_mean(new_cov_dt1)
+
+ op2 = moving_averages.assign_moving_average(
+ self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS)
+
+ # TODO(b/69112164):
+ # It's important that _cov and _cov_dt1 remain consistent with each
+ # other while the inverse ops are happening. How can we ensure this?
+ # We will need to add explicit synchronization for this to
+ # work with asynchronous training.
+ op = control_flow_ops.group(op, op2)
+
+ return op
+
+ def _compute_new_cov_dt1(self, source, tower): # pylint: disable=missing-docstring
+ tensor = self._tensors[source][tower]
+ if self._has_bias:
+ # This appending is technically done twice (the other time is for
+ # _compute_new_cov())
+ tensor = append_homog(tensor)
+
+ total_len = array_ops.shape(tensor)[0]
+ batch_size = total_len // self._num_timesteps
+
+ tensor_present = tensor[:-batch_size, :]
+ tensor_future = tensor[batch_size:, :]
+
+ # We specify a normalizer for this computation to ensure a PSD Fisher
+ # block estimate. This is equivalent to padding with zeros, as was done
+ # in Section B.2 of the appendix.
+ return compute_cov(
+ tensor_future, tensor_right=tensor_present, normalizer=total_len)
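+    # A rough picture of the slicing above (hypothetical sizes): with
+    # num_uses = 3 and batch_size = 2, the rows of `tensor` are ordered
+    # [t0_b0, t0_b1, t1_b0, t1_b1, t2_b0, t2_b1]. Dropping the last batch_size
+    # rows gives the "present" time-steps {t0, t1}, and dropping the first
+    # batch_size rows gives the "future" time-steps {t1, t2}, so aligned rows
+    # are one time-step apart for the same example.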
+
+ def _get_data_device(self, tower):
+ return self._tensors[0][tower].device
+
+ @property
+ def _vec_shape(self):
+ size = self._tensors[0][0].shape[1] + self._has_bias
+ return [size]
+
+ def get_option1quants(self, damping_func):
+ damping_id = graph_func_to_id(damping_func)
+ return self._option1quants_by_damping[damping_id]
+
+ def get_option2quants(self, damping_func):
+ damping_id = graph_func_to_id(damping_func)
+ return self._option2quants_by_damping[damping_id]
+
+ def get_cov_dt1(self):
+ assert self._cov_dt1 is not None
+ return self._cov_dt1
+
+ def register_cov_dt1(self):
+ self._make_cov_dt1 = True
+
+ def instantiate_cov_variables(self):
+ super(FullyConnectedMultiKF, self).instantiate_cov_variables()
+ assert self._cov_dt1 is None
+ if self._make_cov_dt1:
+ with variable_scope.variable_scope(self._var_scope):
+ self._cov_dt1 = variable_scope.get_variable(
+ "cov_dt1",
+ initializer=init_ops.zeros_initializer,
+ shape=self._cov_shape,
+ trainable=False,
+ dtype=self._dtype)
+
+ def register_option1quants(self, damping_func):
+ damping_id = self._register_damping(damping_func)
+ if damping_id not in self._option1quants_registrations:
+ self._option1quants_registrations.add(damping_id)
+
+ def register_option2quants(self, damping_func):
+ damping_id = self._register_damping(damping_func)
+ if damping_id not in self._option2quants_registrations:
+ self._option2quants_registrations.add(damping_id)
+
+ def instantiate_inv_variables(self):
+ super(FullyConnectedMultiKF, self).instantiate_inv_variables()
+
+ for damping_id in self._option1quants_registrations:
+ damping_func = self._damping_funcs_by_id[damping_id]
+ damping_string = graph_func_to_string(damping_func)
+ # It's questionable as to whether we should initialize with stuff like
+ # this at all. Ideally these values should never be used until they are
+ # updated at least once.
+ with variable_scope.variable_scope(self._var_scope):
+ Lmat = variable_scope.get_variable( # pylint: disable=invalid-name
+ "Lmat_damp{}".format(damping_string),
+ initializer=inverse_initializer,
+ shape=self._cov_shape,
+ trainable=False,
+ dtype=self._dtype)
+ psi = variable_scope.get_variable(
+ "psi_damp{}".format(damping_string),
+ initializer=init_ops.ones_initializer,
+ shape=self._vec_shape,
+ trainable=False,
+ dtype=self._dtype)
+
+ assert damping_id not in self._option1quants_by_damping
+ self._option1quants_by_damping[damping_id] = (Lmat, psi)
+
+ for damping_id in self._option2quants_registrations:
+ damping_func = self._damping_funcs_by_id[damping_id]
+ damping_string = graph_func_to_string(damping_func)
+ # It's questionable as to whether we should initialize with stuff like
+ # this at all. Ideally these values should never be used until they are
+ # updated at least once.
+ with variable_scope.variable_scope(self._var_scope):
+ Pmat = variable_scope.get_variable( # pylint: disable=invalid-name
+ "Lmat_damp{}".format(damping_string),
+ initializer=inverse_initializer,
+ shape=self._cov_shape,
+ trainable=False,
+ dtype=self._dtype)
+ Kmat = variable_scope.get_variable( # pylint: disable=invalid-name
+ "Kmat_damp{}".format(damping_string),
+ initializer=inverse_initializer,
+ shape=self._cov_shape,
+ trainable=False,
+ dtype=self._dtype)
+ mu = variable_scope.get_variable(
+ "mu_damp{}".format(damping_string),
+ initializer=init_ops.ones_initializer,
+ shape=self._vec_shape,
+ trainable=False,
+ dtype=self._dtype)
+
+ assert damping_id not in self._option2quants_by_damping
+ self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu)
+
+ def make_inverse_update_ops(self):
+ """Create and return update ops corresponding to registered computations."""
+ # TODO(b/69918258): Add correctness tests for this method.
+ # pylint: disable=invalid-name
+
+ ops = []
+
+ if (len(self._option1quants_by_damping) +
+ len(self._option2quants_by_damping)):
+
+ # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from
+ # the pseudo-code in the original paper. Because the computations for
+ # the A and G case are essentially the same they can both be performed by
+ # the same class (this one).
+
+ C1 = self.get_cov_dt1()
+
+ # Get the eigendecomposition of C0 (= self.get_cov())
+ eigen_e, eigen_V = self.get_eigendecomp()
+
+ # TODO(b/69678661): Note, there is an implicit assumption here that C1
+ # and C0 (as represented here by its eigen-decomp) are consistent. This
+ # could fail to be the case if self._cov and self._cov_dt1 are not updated
+ # consistently, or are somehow read between or during the cov updates.
+ # Can this possibly happen? Is there a way to prevent it?
+
+ for damping_id, (Lmat_var,
+ psi_var) in self._option1quants_by_damping.items():
+
+ damping = self._damping_funcs_by_id[damping_id]()
+ damping = math_ops.cast(damping, self._dtype)
+
+ invsqrtC0 = math_ops.matmul(
+ eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
+
+ # Might need to enforce symmetry lost due to numerical issues.
+ invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0
+
+ # The following line imposes the symmetry assumed by "Option 1" on C1.
+ # Strangely the code can work okay with this line commented out,
+ # depending on how psd_eig is defined. I'm not sure why.
+ C1 = (C1 + array_ops.transpose(C1)) / 2.0
+
+ # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi})
+ hPsi = math_ops.matmul(math_ops.matmul(invsqrtC0, C1), invsqrtC0)
+
+ # Compute the decomposition U*diag(psi)*U^T = hPsi
+ psi, U = utils.posdef_eig(hPsi)
+
+ # L = C0^(-1/2) * U
+ Lmat = math_ops.matmul(invsqrtC0, U)
+
+ ops.append(Lmat_var.assign(Lmat))
+ ops.append(psi_var.assign(psi))
+
+ for damping_id, (Pmat_var, Kmat_var,
+ mu_var) in self._option2quants_by_damping.items():
+
+ damping = self._damping_funcs_by_id[damping_id]()
+ damping = math_ops.cast(damping, self._dtype)
+
+ # compute C0^(-1/2)
+ invsqrtC0 = math_ops.matmul(
+ eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
+
+ # Might need to enforce symmetry lost due to numerical issues.
+ invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0
+
+ # Compute the product C0^(-1/2) * C1
+ invsqrtC0C1 = math_ops.matmul(invsqrtC0, C1)
+
+ # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi})
+ hPsi = math_ops.matmul(invsqrtC0C1, invsqrtC0)
+
+ # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi
+        # Note that we use the notation mu instead of "m" for the eigenvalues.
+ # Instead of computing the product hPsi^T * hPsi and then doing an
+ # eigen-decomposition of this we just compute the SVD of hPsi and then
+ # square the singular values to get the eigenvalues. For a justification
+ # of this approach, see:
+ # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition
+ sqrtmu, _, E = linalg_ops.svd(hPsi)
+ mu = math_ops.square(sqrtmu)
+
+        # Mathematically, the eigenvalues should not exceed 1.0, but due to
+        # numerical issues, or possible issues with inconsistent values of C1
+        # and (the eigen-decomposition of) C0, they might. So we enforce this
+        # condition.
+ mu = math_ops.minimum(mu, 1.0)
+
+ # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1)
+ Pmat = math_ops.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True)
+
+ # K = C_0^(-1/2) * E
+ Kmat = math_ops.matmul(invsqrtC0, E)
+
+ ops.append(Pmat_var.assign(Pmat))
+ ops.append(Kmat_var.assign(Kmat))
+ ops.append(mu_var.assign(mu))
+
+ ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops()
+ return [control_flow_ops.group(*ops)]
+
+ # pylint: enable=invalid-name
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
new file mode 100644
index 0000000000..2d8e378a93
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
@@ -0,0 +1,38 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""FisherFactor definitions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.kfac.python.ops.fisher_factors import *
+from tensorflow.python.util.all_util import remove_undocumented
+# pylint: enable=unused-import,line-too-long,wildcard-import
+
+_allowed_symbols = [
+ "inverse_initializer", "covariance_initializer",
+ "diagonal_covariance_initializer", "scope_string_from_params",
+ "scope_string_from_name", "scalar_or_tensor_to_string", "FisherFactor",
+ "InverseProvidingFactor", "FullFactor", "DiagonalFactor",
+ "NaiveDiagonalFactor", "EmbeddingInputKroneckerFactor",
+ "FullyConnectedDiagonalFactor", "FullyConnectedKroneckerFactor",
+ "ConvInputKroneckerFactor", "ConvOutputKroneckerFactor",
+ "ConvDiagonalFactor", "set_global_constants", "maybe_colocate_with",
+ "compute_cov", "append_homog"
+]
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
new file mode 100644
index 0000000000..43aa713edc
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -0,0 +1,1269 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Registry for layers and their parameters/variables.
+
+This represents the collection of all layers in the approximate Fisher
+information matrix to which a particular FisherBlock may belong. That is, we
+might have several layer collections for one TF graph (if we have multiple K-FAC
+optimizers being used, for example.)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import defaultdict
+from collections import OrderedDict
+from contextlib import contextmanager
+from functools import partial
+import warnings
+
+import math
+import six
+
+from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
+from tensorflow.contrib.kfac.python.ops import loss_functions as lf
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import nest
+
+# Names for various approximations that can be requested for Fisher blocks.
+APPROX_KRONECKER_NAME = "kron"
+APPROX_DIAGONAL_NAME = "diagonal"
+APPROX_FULL_NAME = "full"
+
+_GENERIC_APPROX_TO_BLOCK_TYPES = {
+ APPROX_FULL_NAME: fb.FullFB,
+ APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB,
+}
+
+_FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES = {
+ APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB,
+ APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB,
+}
+
+_CONV2D_APPROX_TO_BLOCK_TYPES = {
+ APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB,
+ APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB,
+}
+
+_EMBEDDING_APPROX_TO_BLOCK_TYPES = {
+ APPROX_KRONECKER_NAME: fb.EmbeddingKFACFB
+}
+
+APPROX_KRONECKER_INDEP_NAME = "kron_indep"
+APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1"
+APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2"
+
+_FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = {
+ APPROX_KRONECKER_INDEP_NAME: fb.FullyConnectedMultiIndepFB,
+ APPROX_KRONECKER_SERIES_1_NAME: partial(fb.FullyConnectedSeriesFB,
+ option=1),
+ APPROX_KRONECKER_SERIES_2_NAME: partial(fb.FullyConnectedSeriesFB,
+ option=2)
+}
+
+_CONV2D_MULTI_APPROX_TO_BLOCK_TYPES = {
+ APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB
+}
+
+_EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = {
+ APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB
+}
+
+# Possible value for `reuse` keyword argument. Sets `reuse` to
+# tf.get_variable_scope().reuse.
+VARIABLE_SCOPE = "VARIABLE_SCOPE"
+
+_DEFAULT_LAYER_COLLECTION = None
+
+
+def get_default_layer_collection():
+ """Get default LayerCollection."""
+ if _DEFAULT_LAYER_COLLECTION is None:
+ raise ValueError(
+ "Attempted to retrieve default LayerCollection when none is set. Use "
+ "LayerCollection.as_default().")
+
+ return _DEFAULT_LAYER_COLLECTION
+
+
+def set_default_layer_collection(layer_collection):
+ global _DEFAULT_LAYER_COLLECTION
+
+ if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None:
+ raise ValueError("Default LayerCollection is already set.")
+
+ _DEFAULT_LAYER_COLLECTION = layer_collection
+
+
+class LayerParametersDict(OrderedDict):
+ """An OrderedDict where keys are Tensors or tuples of Tensors.
+
+ Ensures that no Tensor is associated with two different keys.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self._tensors = set()
+ super(LayerParametersDict, self).__init__(*args, **kwargs)
+
+ def __setitem__(self, key, value):
+ key = self._canonicalize_key(key)
+ tensors = key if isinstance(key, (tuple, list)) else (key,)
+ key_collisions = self._tensors.intersection(tensors)
+ if key_collisions:
+ raise ValueError("Key(s) already present: {}".format(key_collisions))
+ self._tensors.update(tensors)
+ super(LayerParametersDict, self).__setitem__(key, value)
+
+ def __delitem__(self, key):
+ key = self._canonicalize_key(key)
+ self._tensors.remove(key)
+ super(LayerParametersDict, self).__delitem__(key)
+
+ def __getitem__(self, key):
+ key = self._canonicalize_key(key)
+ return super(LayerParametersDict, self).__getitem__(key)
+
+ def __contains__(self, key):
+ key = self._canonicalize_key(key)
+ return super(LayerParametersDict, self).__contains__(key)
+
+ def _canonicalize_key(self, key):
+ if isinstance(key, (list, tuple)):
+ return tuple(key)
+ return key
+
+
+# TODO(b/68034464): add capability for LayerCollection to be "finalized"
+# and do this when it gets used by FisherEstimator / KfacOptimizer.
+
+
+class LayerCollection(object):
+ """Registry of information about layers and losses.
+
+ Note that you need to create a new one of these for each MatrixEstimator or
+ KfacOptimizer.
+
+ Attributes:
+    fisher_blocks: a LayerParametersDict (subclass of OrderedDict) mapping
+      layer parameters (Tensors or tuples of Tensors) to FisherBlock
+      instances.
+ fisher_factors: an OrderedDict mapping tuples to FisherFactor instances.
+ losses: a list of LossFunction objects. The loss to be optimized is their
+ sum.
+ loss_colocation_ops: ops to colocate loss function evaluations with. These
+ will typically be the inputs to the losses.
+ """
+
+ def __init__(self,
+ graph=None,
+ name="LayerCollection"):
+ warnings.warn(
+ "tf.contrib.kfac is deprecated and will be removed by 2018-11-01. "
+ "Use https://pypi.python.org/pypi/kfac instead.")
+ self.fisher_blocks = LayerParametersDict()
+ self.fisher_factors = OrderedDict()
+    # Dict mapping sets of variables to optionally specified approximations.
+    self._linked_parameters = dict()
+ self._graph = graph or ops.get_default_graph()
+ self._loss_dict = {} # {str: LossFunction}
+ self._subgraph = None
+ self._default_generic_approximation = APPROX_DIAGONAL_NAME
+ self._default_embedding_approximation = APPROX_KRONECKER_NAME
+ self._default_fully_connected_approximation = APPROX_KRONECKER_NAME
+ self._default_conv2d_approximation = APPROX_KRONECKER_NAME
+ self._default_fully_connected_multi_approximation = (
+ APPROX_KRONECKER_INDEP_NAME)
+ self._default_conv2d_multi_approximation = (
+ APPROX_KRONECKER_INDEP_NAME)
+ self._default_embedding_multi_approximation = APPROX_KRONECKER_INDEP_NAME
+ self.loss_colocation_ops = {}
+ self._vars_to_uses = defaultdict(lambda: 0)
+
+ with variable_scope.variable_scope(None, default_name=name) as scope:
+ self._var_scope = scope.name
+
+ @property
+ def losses(self):
+ """Tuple of LossFunction objects registered with this LayerCollection."""
+ return nest.flatten(self.towers_by_loss)
+
+ @property
+ def towers_by_loss(self):
+ """Tuple across losses of LossFunction objects registered to each tower."""
+ return tuple(tuple(lst) for lst in self._loss_dict.values())
+
+ @property
+ def registered_variables(self):
+ """A tuple of all of the variables currently registered."""
+ tuple_of_tuples = (utils.ensure_sequence(key) for key, block
+ in six.iteritems(self.fisher_blocks))
+ flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_)
+ return flat_tuple
+
+ @property
+ def linked_parameters(self):
+ """Groups of parameters with an optionally specified approximation.
+
+ Linked parameters can be added using `define_linked_parameters`.
+ If an approximation is specified, then this approximation will be used
+ when registering a layer with exactly these parameters, unless an
+ approximation is specified when calling the registration function.
+
+ Returns:
+ A `dict` mapping tuples of parameters to an optional string.
+ """
+ return self._linked_parameters
+
+ @property
+ def default_embedding_approximation(self):
+ return self._default_embedding_approximation
+
+ def set_default_embedding_approximation(self, value):
+ if value != APPROX_KRONECKER_NAME:
+ raise ValueError(
+ "{} is not a valid approximation for embedding variables.".format(
+ value))
+ self._default_embedding_approximation = value
+
+ @property
+ def default_generic_approximation(self):
+ return self._default_generic_approximation
+
+ def set_default_generic_approximation(self, value):
+ if value not in _GENERIC_APPROX_TO_BLOCK_TYPES:
+ raise ValueError(
+ "{} is not a valid approximation for generic variables.".format(
+ value))
+ self._default_generic_approximation = value
+
+ @property
+ def default_fully_connected_approximation(self):
+ return self._default_fully_connected_approximation
+
+ def set_default_fully_connected_approximation(self, value):
+ if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES:
+ raise ValueError(
+ "{} is not a valid approximation for fully connected layers.".format(
+ value))
+ self._default_fully_connected_approximation = value
+
+ @property
+ def default_conv2d_approximation(self):
+ return self._default_conv2d_approximation
+
+ def set_default_conv2d_approximation(self, value):
+ if value not in _CONV2D_APPROX_TO_BLOCK_TYPES:
+ raise ValueError(
+ "{} is not a valid approximation for 2d convolutional layers.".format(
+ value))
+ self._default_conv2d_approximation = value
+
+ @property
+ def default_fully_connected_multi_approximation(self):
+ return self._default_fully_connected_multi_approximation
+
+ def set_default_fully_connected_multi_approximation(self, value):
+ if value not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES:
+ raise ValueError("{} is not a valid approximation for a fully-connected "
+ "multi layer.".format(value))
+ self._default_fully_connected_multi_approximation = value
+
+ @property
+ def default_conv2d_multi_approximation(self):
+ return self._default_conv2d_multi_approximation
+
+ @property
+ def default_embedding_multi_approximation(self):
+ return self._default_embedding_multi_approximation
+
+ def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
+ """Validates and registers the layer_key associated with the fisher_block.
+
+ Args:
+ layer_key: A variable or tuple of variables. The key to check for in
+ existing registrations and to register if valid.
+ fisher_block: The associated `FisherBlock`.
+      reuse: Method to use for inserting new `FisherBlock`s. One of True, False,
+ or `VARIABLE_SCOPE`.
+
+ Raises:
+ ValueError: If `layer_key` was already registered and reuse is `False`,
+ if `layer_key` was registered with a different block type, or if
+ `layer_key` shares any variables with but is not equal to a previously
+ registered key.
+ KeyError: If `reuse` is `True` but `layer_key` was not previously
+ registered.
+
+ Returns:
+ The `FisherBlock` registered under `layer_key`. If `layer_key` was already
+ registered, this will be the previously registered `FisherBlock`.
+ """
+ if reuse is VARIABLE_SCOPE:
+ reuse = variable_scope.get_variable_scope().reuse
+
+ if reuse is True or (reuse is variable_scope.AUTO_REUSE and
+ layer_key in self.fisher_blocks):
+ result = self.fisher_blocks[layer_key]
+ if type(result) != type(fisher_block): # pylint: disable=unidiomatic-typecheck
+ raise ValueError(
+ "Attempted to register FisherBlock of type %s when existing "
+ "FisherBlock has type %s." % (type(fisher_block), type(result)))
+ return result
+ if reuse is False and layer_key in self.fisher_blocks:
+ raise ValueError("FisherBlock for %s is already in LayerCollection." %
+ (layer_key,))
+
+ # Insert fisher_block into self.fisher_blocks.
+ if layer_key in self.fisher_blocks:
+ raise ValueError("Duplicate registration: {}".format(layer_key))
+ # Raise an error if any variable in layer_key has been registered in any
+ # other blocks.
+ variable_to_block = {
+ var: (params, block)
+ for (params, block) in self.fisher_blocks.items()
+ for var in utils.ensure_sequence(params)
+ }
+ for variable in utils.ensure_sequence(layer_key):
+ if variable in variable_to_block:
+ prev_key, prev_block = variable_to_block[variable]
+ raise ValueError(
+ "Attempted to register layer_key {} with block {}, but variable {}"
+ " was already registered in key {} with block {}.".format(
+ layer_key, fisher_block, variable, prev_key, prev_block))
+ self.fisher_blocks[layer_key] = fisher_block
+ return fisher_block
+
+ def register_loss_function(self,
+ loss,
+ colocation_op,
+ base_name,
+ name=None,
+ reuse=VARIABLE_SCOPE):
+ """Registers a LossFunction object.
+
+ Args:
+ loss: The LossFunction object.
+ colocation_op: The op to colocate the loss function's computations with.
+      base_name: The name to derive a new unique name from if the `name`
+        argument is None.
+ name: (OPTIONAL) str or None. Unique name for this loss function. If None,
+ a new name is generated. (Default: None)
+ reuse: (OPTIONAL) bool or str. If True, adds `loss` as an additional
+ tower for the existing loss function.
+
+ Raises:
+ ValueError: If reuse == True and name == None.
+ ValueError: If reuse == True and seed != None.
+ KeyError: If reuse == True and no existing LossFunction with `name` found.
+ KeyError: If reuse == False and existing LossFunction with `name` found.
+ """
+
+ name = name or self._graph.unique_name(base_name)
+
+ if reuse == VARIABLE_SCOPE:
+ reuse = variable_scope.get_variable_scope().reuse
+
+ if reuse:
+ if name is None:
+ raise ValueError(
+ "If reuse is enabled, loss function's name must be set.")
+
+ loss_list = self._loss_dict.get(name, None)
+
+ if loss_list is None:
+ raise KeyError(
+ "Unable to find loss function named {}. Register a new loss "
+ "function with reuse=False.".format(name))
+ else:
+ if name in self._loss_dict:
+ raise KeyError(
+ "Loss function named {} already exists. Set reuse=True to append "
+ "another tower.".format(name))
+
+ loss_list = []
+ self._loss_dict[name] = loss_list
+
+ loss_list.append(loss)
+ self.loss_colocation_ops[loss] = colocation_op
+
+ def _get_use_count_map(self):
+ """Returns a dict mapping variables to their number of registrations."""
+ return self._vars_to_uses
+
+ def _add_uses(self, params, uses):
+ """Register additional uses by params in the graph.
+
+ Args:
+ params: Variable or tuple of Variables. Parameters for a layer.
+ uses: int or float. Number of additional uses for these parameters.
+ """
+ params = params if isinstance(params, (tuple, list)) else (params,)
+ for var in params:
+ self._vars_to_uses[var] += uses
+
+ def check_registration(self, variables):
+ """Checks that all variable uses have been registered properly.
+
+ Args:
+ variables: List of variables.
+
+ Raises:
+ ValueError: If any registered variables are not included in the list.
+ ValueError: If any variable in the list is not registered.
+ ValueError: If any variable in the list is registered with the wrong
+ number of "uses" in the subgraph recorded (vs the number of times that
+ variable is actually used in the subgraph).
+ """
+ # Note that overlapping parameters (i.e. those that share variables) will
+ # be caught by layer_collection.LayerParametersDict during registration.
+
+ reg_use_map = self._get_use_count_map()
+
+ error_messages = []
+
+ for var in variables:
+ total_uses = self.subgraph.variable_uses(var)
+ reg_uses = reg_use_map[var]
+
+ if reg_uses == 0:
+ error_messages.append("Variable {} not registered.".format(var))
+ elif (not math.isinf(reg_uses)) and reg_uses != total_uses:
+ error_messages.append(
+ "Variable {} registered with wrong number of uses ({} "
+ "registrations vs {} uses).".format(var, reg_uses, total_uses))
+
+ num_get_vars = len(reg_use_map)
+
+ if num_get_vars > len(variables):
+ error_messages.append("{} registered variables were not included in list."
+ .format(num_get_vars - len(variables)))
+
+ if error_messages:
+ error_messages = [
+ "Found the following errors with variable registration:"
+ ] + error_messages
+ raise ValueError("\n\t".join(error_messages))
+
+ def get_blocks(self):
+ return self.fisher_blocks.values()
+
+ def get_factors(self):
+ return self.fisher_factors.values()
+
+ @property
+ def graph(self):
+ return self._graph
+
+ @property
+ def subgraph(self):
+ return self._subgraph
+
+ def define_linked_parameters(self, params, approximation=None):
+ """Identify a set of parameters that should be grouped together.
+
+ During automatic graph scanning, any matches containing variables that have
+ been identified as part of a linked group will be filtered out unless
+ the match parameters are exactly equal to the ones specified in the linked
+ group.
+
+ Args:
+ params: A variable, or a tuple or list of variables. The variables
+ to be linked.
+ approximation: Optional string specifying the type of approximation to use
+ for these variables. If unspecified, this layer collection's default
+ approximation for the layer type will be used.
+
+ Raises:
+ ValueError: If the parameters were already registered in a layer or
+ identified as part of an incompatible group.
+ """
+ params = frozenset(utils.ensure_sequence(params))
+
+ # Check if any of the variables in `params` is already in
+ # 'self.fisher_blocks.keys()`.
+ for registered_params, fisher_block in self.fisher_blocks.items():
+ registered_params_set = set(utils.ensure_sequence(registered_params))
+ for variable in params:
+ if (variable in registered_params_set and
+ params != registered_params_set):
+ raise ValueError(
+ "Can`t link parameters {}, variable {} was already registered in "
+ "group {} with layer {}".format(params, variable,
+ registered_params, fisher_block))
+
+ # Check if any of the variables in `params` is already in
+ # 'self.linked_parameters`.
+ for variable in params:
+ for other_linked_params in self.linked_parameters:
+ if variable in other_linked_params:
+ raise ValueError("Can`t link parameters {}, variable {} was already "
+ "linked in group {}.".format(params, variable,
+ other_linked_params))
+ self._linked_parameters[params] = approximation
+
+ def create_subgraph(self):
+ if not self.losses:
+ raise ValueError("Must have at least one registered loss.")
+ inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses))
+ self._subgraph = utils.SubGraph(inputs_to_losses)
+
+ def eval_losses(self):
+ """Return evaluated losses (colocated with inputs to losses)."""
+ evals = []
+ for loss in self.losses:
+ with ops.colocate_with(self.loss_colocation_ops[loss]):
+ evals.append(loss.evaluate())
+ return evals
+
+ def eval_losses_on_samples(self):
+ """Return losses evaluated on samples (colocated with inputs to losses)."""
+ evals = []
+ for loss in self.losses:
+ with ops.colocate_with(self.loss_colocation_ops[loss]):
+ evals.append(loss.evaluate_on_sample())
+ return evals
+
+ def total_loss(self):
+ return math_ops.add_n(self.eval_losses())
+
+ def total_sampled_loss(self):
+ return math_ops.add_n(self.eval_losses_on_samples())
+
+ def _get_linked_approx(self, params):
+ """If params were linked, return their specified approximation."""
+ params_set = frozenset(utils.ensure_sequence(params))
+ if params_set in self.linked_parameters:
+ return self.linked_parameters[params_set]
+ else:
+ return None
+
+ def _get_block_type(self, params, approx, default, approx_to_type):
+ if approx is None:
+ approx = self._get_linked_approx(params)
+ if approx is None:
+ approx = default
+
+ if approx not in approx_to_type:
+ raise ValueError("Bad value {} for approx.".format(approx))
+
+ return approx_to_type[approx], approx
+
+ def register_embedding(self,
+ params,
+ inputs,
+ outputs,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Registers an embedding layer.
+
+ Args:
+ params: Embedding matrix of shape [vocab_size, embedding_size].
+ inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices
+ into embedding matrix.
+ outputs: Tensor of shape [batch_size, embedding_size]. Outputs
+ produced by layer.
+ approx: str or None. If not None must be "kron". The Fisher
+ approximation to use. If None the default value is used. (Default: None)
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+ (Default: "VARIABLE_SCOPE")
+
+ Raises:
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ block_type, approx = self._get_block_type(
+ params, approx, self.default_embedding_approximation,
+ _EMBEDDING_APPROX_TO_BLOCK_TYPES)
+
+ if isinstance(params, (tuple, list)):
+ raise ValueError("Bias not supported.")
+ vocab_size = int(params.shape[0])
+ block = self.register_block(
+ params, block_type(self, vocab_size), reuse=reuse)
+ block.register_additional_tower(inputs, outputs)
+
+ self._add_uses(params, 1)
+
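+  # A minimal usage sketch for register_embedding (illustrative; assumes `tf`
+  # is TensorFlow, `lc` is a LayerCollection, and `embedding` and `ids` were
+  # created elsewhere):
+  #
+  #   embedded = tf.nn.embedding_lookup(embedding, ids)
+  #   lc.register_embedding(embedding, inputs=ids, outputs=embedded)
+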
+ def register_fully_connected(self,
+ params,
+ inputs,
+ outputs,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Registers a fully connected layer.
+
+ Args:
+ params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
+ this layer. Weight matrix should have shape [input_size, output_size].
+ Bias should have shape [output_size].
+ inputs: Tensor of shape [batch_size, input_size]. Inputs to layer.
+ outputs: Tensor of shape [batch_size, output_size]. Outputs
+ produced by layer.
+ approx: str or None. If not None must be one of "kron" or "diagonal".
+ The Fisher approximation to use. If None the default value is used.
+ (Default: None)
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+ (Default: "VARIABLE_SCOPE")
+
+ Raises:
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+
+ block_type, approx = self._get_block_type(
+ params, approx, self.default_fully_connected_approximation,
+ _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES)
+
+ has_bias = isinstance(params, (tuple, list))
+ block = self.register_block(params, block_type(self, has_bias=has_bias),
+ reuse=reuse)
+ block.register_additional_tower(inputs, outputs)
+
+ self._add_uses(params, 1)
+
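+  # A minimal usage sketch for register_fully_connected (illustrative; assumes
+  # `tf` is TensorFlow, `lc` is a LayerCollection, and `x`, `w`, `b` are the
+  # layer's inputs, weight, and bias created elsewhere):
+  #
+  #   pre = tf.matmul(x, w) + b
+  #   lc.register_fully_connected((w, b), inputs=x, outputs=pre)
+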
+ def register_conv2d(self,
+ params,
+ strides,
+ padding,
+ inputs,
+ outputs,
+ data_format=None,
+ dilations=None,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Registers a call to tf.nn.conv2d().
+
+ Args:
+ params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
+ this layer. Weight matrix should have shape [kernel_height,
+ kernel_width, in_channels, out_channels]. Bias should have shape
+ [out_channels].
+ strides: List of 4 ints. Strides for convolution kernel.
+ padding: string. see tf.nn.conv2d for valid values.
+ inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
+ to layer.
+ outputs: Tensor of shape [batch_size, height, width, out_channels].
+ Output produced by layer.
+ data_format: str or None. Format of data.
+ dilations: List of 4 ints. Dilations along each dimension.
+ approx: str or None. If not None must be one of "kron" or "diagonal".
+ The Fisher approximation to use. If None the default value is used.
+ (Default: None)
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+ (Default: "VARIABLE_SCOPE")
+
+ Raises:
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+
+ block_type, approx = self._get_block_type(
+ params, approx, self.default_conv2d_approximation,
+ _CONV2D_APPROX_TO_BLOCK_TYPES)
+
+ # It feels bad to pass in configuration that has to do with the internal
+    # implementation. And then we can't use the same constructor for both
+ # anymore and are thus forced to use this ugly if-statement.
+ # TODO(b/74793309): Clean this up?
+ if approx == APPROX_KRONECKER_NAME:
+ block = self.register_block(
+ params,
+ block_type(
+ layer_collection=self,
+ params=params,
+ padding=padding,
+ strides=strides,
+ data_format=data_format,
+ dilation_rate=dilations,
+ extract_patches_fn="extract_image_patches"),
+ reuse=reuse)
+ elif approx == APPROX_DIAGONAL_NAME:
+ assert strides[0] == strides[-1] == 1
+ block = self.register_block(
+ params,
+ block_type(
+ layer_collection=self,
+ params=params,
+ padding=padding,
+ strides=strides,
+ dilations=dilations,
+ data_format=data_format),
+ reuse=reuse)
+ else:
+ raise NotImplementedError(approx)
+
+ block.register_additional_tower(inputs, outputs)
+
+ self._add_uses(params, 1)
+
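+  # A minimal usage sketch for register_conv2d (illustrative; assumes `tf` is
+  # TensorFlow, `lc` is a LayerCollection, and `x`, `kernel`, `bias` were
+  # created elsewhere):
+  #
+  #   conv = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding="SAME")
+  #   out = conv + bias
+  #   lc.register_conv2d((kernel, bias), strides=[1, 1, 1, 1], padding="SAME",
+  #                      inputs=x, outputs=out)
+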
+ def register_convolution(self,
+ params,
+ inputs,
+ outputs,
+ padding,
+ strides=None,
+ dilation_rate=None,
+ data_format=None,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Register a call to tf.nn.convolution().
+
+ Args:
+ params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
+ this layer. Weight matrix should have shape [..filter_spatial_size..,
+ in_channels, out_channels]. Bias should have shape [out_channels].
+ inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels].
+ Inputs to layer.
+ outputs: Tensor of shape [batch_size, ..output_spatial_size..,
+ out_channels]. Output produced by layer.
+ padding: string. see tf.nn.conv2d for valid values.
+ strides: List of ints of length len(..input_spatial_size..). Strides for
+ convolution kernel in spatial dimensions.
+ dilation_rate: List of ints of length len(..input_spatial_size..).
+ Dilations along spatial dimension.
+ data_format: str or None. Format of data.
+ approx: str or None. If not None must be one of "kron" or "diagonal".
+ The Fisher approximation to use. If None the default value is used.
+ (Default: None)
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+ (Default: "VARIABLE_SCOPE")
+
+ Raises:
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ # TODO(b/74793309): Have this use _get_block_type like the other
+ # registration functions?
+ assert approx is None or approx == APPROX_KRONECKER_NAME
+
+ block = self.register_block(
+ params,
+ fb.ConvKFCBasicFB(
+ layer_collection=self,
+ params=params,
+ padding=padding,
+ strides=strides,
+ dilation_rate=dilation_rate,
+ data_format=data_format),
+ reuse=reuse)
+ block.register_additional_tower(inputs, outputs)
+
+ self._add_uses(params, 1)
+
+ def register_depthwise_conv2d(self,
+ params,
+ inputs,
+ outputs,
+ strides,
+ padding,
+ rate=None,
+ data_format=None,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Register a call to tf.nn.depthwise_conv2d().
+
+ Args:
+ params: 4-D Tensor of shape [filter_height, filter_width,
+ in_channels, channel_multiplier]. Convolutional filter.
+ inputs: Tensor of shape [batch_size, input_height, input_width,
+ in_channels]. Inputs to layer.
+ outputs: Tensor of shape [batch_size, output_height, output_width,
+ in_channels * channel_multiplier]. Output produced by depthwise conv2d.
+ strides: List of ints of length 4. Strides along all dimensions.
+ padding: string. see tf.nn.conv2d for valid values.
+ rate: None or List of ints of length 2. Dilation rates in spatial
+ dimensions.
+ data_format: str or None. Format of data.
+ approx: str or None. If not None must "diagonal". The Fisher
+ approximation to use. If None the default value is used. (Default: None)
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+ (Default: "VARIABLE_SCOPE")
+
+ Raises:
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ # TODO(b/74793309): Have this use _get_block_type like the other
+ # registration functions?
+ assert approx is None or approx == APPROX_DIAGONAL_NAME
+ assert data_format in [None, "NHWC"]
+
+ block = self.register_block(
+ params,
+ fb.DepthwiseConvDiagonalFB(
+ layer_collection=self,
+ params=params,
+ strides=strides,
+ padding=padding,
+ rate=rate,
+ data_format=data_format),
+ reuse=reuse)
+ block.register_additional_tower(inputs, outputs)
+
+ self._add_uses(params, 1)
+
+ def register_separable_conv2d(self,
+ depthwise_params,
+ pointwise_params,
+ inputs,
+ depthwise_outputs,
+ pointwise_outputs,
+ strides,
+ padding,
+ rate=None,
+ data_format=None,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Register a call to tf.nn.separable_conv2d().
+
+ Note: This requires access to intermediate outputs between depthwise and
+ pointwise convolutions.
+
+ Args:
+ depthwise_params: 4-D Tensor of shape [filter_height, filter_width,
+ in_channels, channel_multiplier]. Filter for depthwise conv2d.
+ pointwise_params: 4-D Tensor of shape [1, 1, in_channels *
+ channel_multiplier, out_channels]. Filter for pointwise conv2d.
+ inputs: Tensor of shape [batch_size, input_height, input_width,
+ in_channels]. Inputs to layer.
+ depthwise_outputs: Tensor of shape [batch_size, output_height,
+ output_width, in_channels * channel_multiplier]. Output produced by
+ depthwise conv2d.
+ pointwise_outputs: Tensor of shape [batch_size, output_height,
+ output_width, out_channels]. Output produced by pointwise conv2d.
+ strides: List of ints of length 4. Strides for depthwise conv2d kernel in
+ all dimensions.
+ padding: string. see tf.nn.conv2d for valid values.
+ rate: None or List of ints of length 2. Dilation rate of depthwise conv2d
+ kernel in spatial dimensions.
+ data_format: str or None. Format of data.
+ approx: str or None. If not None must be one of "kron" or "diagonal".
+ The Fisher approximation to use. If None the default value is used.
+ (Default: None)
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
+ (Default: "VARIABLE_SCOPE")
+
+ Raises:
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ self.register_depthwise_conv2d(
+ params=depthwise_params,
+ inputs=inputs,
+ outputs=depthwise_outputs,
+ strides=strides,
+ padding=padding,
+ rate=rate,
+ data_format=data_format,
+ approx=APPROX_DIAGONAL_NAME,
+ reuse=reuse)
+
+ self.register_conv2d(
+ params=pointwise_params,
+ inputs=depthwise_outputs,
+ outputs=pointwise_outputs,
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ data_format=data_format,
+ approx=approx,
+ reuse=reuse)
+
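+  # A minimal usage sketch for register_separable_conv2d (illustrative;
+  # assumes `tf` is TensorFlow, `lc` is a LayerCollection, and `x`,
+  # `dw_filter`, `pw_filter` were created elsewhere). The intermediate
+  # depthwise output must be materialized explicitly so it can be passed in:
+  #
+  #   dw = tf.nn.depthwise_conv2d(x, dw_filter, strides=[1, 1, 1, 1],
+  #                               padding="SAME")
+  #   pw = tf.nn.conv2d(dw, pw_filter, strides=[1, 1, 1, 1], padding="VALID")
+  #   lc.register_separable_conv2d(
+  #       dw_filter, pw_filter, inputs=x, depthwise_outputs=dw,
+  #       pointwise_outputs=pw, strides=[1, 1, 1, 1], padding="SAME")
+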
+ def register_generic(self,
+ params,
+ batch_size,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Registers a generic layer.
+
+ Args:
+ params: Tensor or tuple of Tensors corresponding to the parameters.
+ batch_size: 0-D Tensor. Size of the minibatch (for this tower).
+      approx: str or None. If not None, must be one of "full" or "diagonal".
+ The Fisher approximation to use. If None the default value is used.
+ (Default: None)
+ reuse: bool or str. If True, this adds `batch_size` to the total
+ mini-batch size use when estimating the Fisher block for this layer
+ (which must have already been registered). If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
+
+ Raises:
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ block_type, approx = self._get_block_type(
+ params, approx, self.default_generic_approximation,
+ _GENERIC_APPROX_TO_BLOCK_TYPES)
+
+ block = self.register_block(params, block_type(self, params), reuse=reuse)
+ block.register_additional_tower(batch_size)
+
+ self._add_uses(params, float("inf"))
+
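+  # A minimal usage sketch for register_generic (illustrative; assumes `tf` is
+  # TensorFlow, `lc` is a LayerCollection, and `params` and `x` were created
+  # elsewhere). This is intended for parameters whose layer structure K-FAC
+  # cannot exploit:
+  #
+  #   lc.register_generic(params, batch_size=tf.shape(x)[0],
+  #                       approx=APPROX_DIAGONAL_NAME)
+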
+ def register_fully_connected_multi(self, params, inputs, outputs,
+ num_uses=None, approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Register fully connected layers with shared parameters.
+
+ This can handle general fully-connected layers with shared parameters, but
+ has specialized approximations to deal with the case where there is a
+    meaningful linear order to the shared instances (such as in an RNN).
+
+ Args:
+ params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
+ this layer. Weight matrix should have shape [input_size, output_size].
+ Bias should have shape [output_size].
+ inputs: A list of Tensors, each of shape [batch_size, input_size]. Inputs
+ to layer. The list indexes each use in the graph (which might
+ correspond to a "time-step" in an RNN). OR, can be single Tensor, of
+        shape [num_uses * batch_size, input_size], which is a reshaped version
+ of a Tensor of shape [num_uses, batch_size, input_size].
+ outputs: A list of Tensors, the same length as `inputs`, each of shape
+ [batch_size, output_size]. Outputs produced by layer. The list indexes
+ each use in the graph (which might correspond to a "time-step" in an
+ RNN). Needs to correspond with the order used in `inputs`. OR, can be
+ a single Tensor of shape [num_uses * batch_size, output_size], which is
+ a reshaped version of a Tensor of shape [num_uses, batch_size,
+ output_size].
+      num_uses: int or None. The number of uses/time-steps in the graph where
+        the layer appears. Only needed if both inputs and outputs are given in
+        the single Tensor format. (Default: None)
+      approx: str or None. If not None, must be one of "kron_indep",
+        "kron_series_1" or "kron_series_2". The Fisher approximation to use.
+        If None the default value is used. (Default: None)
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
+ word `use` here has a completely different meaning to "use in the graph"
+ as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
+ (Default: "VARIABLE_SCOPE")
+
+ Raises:
+ ValueError: For improper value to `approx`.
+ """
+ block_type, approx = self._get_block_type(
+ params, approx, self.default_fully_connected_multi_approximation,
+ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES)
+
+ # TODO(b/70283649): something along the lines of find_canonical_output
+ # should be added back in here (and for the other block types, arguably).
+
+ has_bias = isinstance(params, (tuple, list))
+ block = self.register_block(params, block_type(self, has_bias=has_bias,
+ num_uses=num_uses),
+ reuse=reuse)
+ block.register_additional_tower(inputs, outputs)
+ if isinstance(inputs, (tuple, list)):
+ assert len(inputs) == len(outputs)
+ self._add_uses(params, len(inputs))
+ else:
+ self._add_uses(params, 1)
+
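+  # A minimal usage sketch for register_fully_connected_multi with weights
+  # shared across time steps (illustrative; assumes `tf` is TensorFlow, `lc`
+  # is a LayerCollection, and `w`, `b`, and the list of per-step inputs `xs`
+  # were created elsewhere):
+  #
+  #   outs = [tf.matmul(x_t, w) + b for x_t in xs]
+  #   lc.register_fully_connected_multi((w, b), inputs=xs, outputs=outs)
+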
+ def register_conv2d_multi(self,
+ params,
+ strides,
+ padding,
+ inputs,
+ outputs,
+ num_uses=None,
+ data_format=None,
+ dilations=None,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Registers convolutional layers with shared parameters.
+
+ Args:
+ params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
+ this layer. Weight matrix should have shape [kernel_height,
+ kernel_width, in_channels, out_channels]. Bias should have shape
+ [out_channels].
+ strides: 1-D Tensor of length 4. Strides for convolution kernel.
+ padding: string. see tf.nn.conv2d for valid values.
+ inputs: A list of Tensors, each of shape [batch_size, height, width,
+ in_channels]. Inputs to layer. The list indexes each use in the graph
+ (which might correspond to a "time-step" in an RNN). OR, can be single
+ Tensor, of shape [num_uses * batch_size, height, width, in_channels],
+ which is a reshaped version of a Tensor of shape [num_uses, batch_size,
+ height, width, in_channels].
+ outputs: A list of Tensors, each of shape [batch_size, height, width,
+ out_channels]. Output produced by layer. The list indexes each use
+ in the graph (which might correspond to a "time-step" in an RNN).
+ Needs to correspond with the order used in `inputs`. OR, can be a
+ single Tensor, of shape [num_uses * batch_size, height, width,
+ out_channels], which is a reshaped version of a Tensor of shape
+ [num_uses, batch_size, height, width, out_channels].
+      num_uses: int or None. The number of uses/time-steps in the graph where
+        the layer appears. Only needed if both inputs and outputs are given in
+        the single Tensor format. (Default: None)
+ data_format: str or None. Format of data.
+ dilations: List of 4 ints. Dilations along each dimension.
+      approx: str or None. If not None must be "kron_indep". The Fisher
+ approximation to use. If None the default value is used.
+ (Default: None)
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
+ word `use` here has a completely different meaning to "use in the graph"
+ as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
+ (Default: "VARIABLE_SCOPE")
+
+ Raises:
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ block_type, approx = self._get_block_type(
+ params, approx, self.default_conv2d_multi_approximation,
+ _CONV2D_MULTI_APPROX_TO_BLOCK_TYPES)
+
+ block = self.register_block(
+ params,
+ block_type(
+ layer_collection=self,
+ params=params,
+ padding=padding,
+ strides=strides,
+ data_format=data_format,
+ dilation_rate=dilations,
+ extract_patches_fn="extract_image_patches",
+ num_uses=num_uses),
+ reuse=reuse)
+
+ block.register_additional_tower(inputs, outputs)
+ if isinstance(inputs, (tuple, list)):
+ assert len(inputs) == len(outputs)
+ self._add_uses(params, len(inputs))
+ else:
+ self._add_uses(params, 1)
+
+ # TODO(b/74108452): change the loss registration functions names to refer
+ # to "loss functions" instead of distributions. Following naming convention
+ # of the loss function classes themselves.
+
+ def register_embedding_multi(self,
+ params,
+ inputs,
+ outputs,
+ num_uses=None,
+ approx=None,
+ reuse=VARIABLE_SCOPE):
+ """Registers embedding layers with shared parameters.
+
+ Args:
+ params: Embedding matrix of shape [vocab_size, embedding_size].
+ inputs: A list of Tensors, each of shape [batch_size, input_size] and
+ dtype int32. Indices into embedding matrix. The list indexes each use
+ in the graph (which might correspond to a "time-step" in an RNN).
+ OR, can be single Tensor, of shape [num_uses*batch_size, input_size],
+ which is a reshaped version of a Tensor of shape [num_uses, batch_size,
+ input_size].
+ outputs: A list of Tensors, each of shape [batch_size, embedding_size].
+ Outputs produced by layer. The list indexes each use in the graph
+ (which might correspond to a "time-step" in an RNN). Needs to
+ correspond with the order used in `inputs`. OR, can be a
+ single Tensor, of shape [num_uses * batch_size, embedding_size], which
+ is a reshaped version of a Tensor of shape [num_uses, batch_size,
+ embedding_size].
+      num_uses: int or None. The number of uses/time-steps in the graph where
+        the layer appears. Only needed if both inputs and outputs are given in
+        the single Tensor format. (Default: None)
+      approx: str or None. If not None must be "kron_indep". The Fisher
+ approximation to use. If None the default value is used.
+ (Default: None)
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
+ additional mini-batch/tower of data to use when estimating the Fisher
+ block for this layer (which must have already been registered). If
+ "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
+ word `use` here has a completely different meaning to "use in the graph"
+ as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
+ (Default: "VARIABLE_SCOPE")
+
+ Raises:
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
+ ValueError: If reuse == True and FisherBlock found but of the wrong type.
+ """
+ block_type, approx = self._get_block_type(
+ params, approx, self.default_embedding_multi_approximation,
+ _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES)
+
+ if isinstance(params, (tuple, list)):
+ raise ValueError("Bias not supported.")
+ vocab_size = int(params.shape[0])
+
+ block = self.register_block(
+ params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse)
+ block.register_additional_tower(inputs, outputs)
+
+ if isinstance(inputs, (tuple, list)):
+ self._add_uses(params, len(inputs))
+ else:
+ self._add_uses(params, 1)
+
+ def register_categorical_predictive_distribution(self,
+ logits,
+ seed=None,
+ targets=None,
+ name=None,
+ reuse=VARIABLE_SCOPE):
+ """Registers a categorical predictive distribution.
+
+ Args:
+ logits: The logits of the distribution (i.e. its parameters).
+ seed: The seed for the RNG (for debugging) (Default: None)
+ targets: (OPTIONAL) The targets for the loss function. Only required if
+ one wants to call total_loss() instead of total_sampled_loss().
+ total_loss() is required, for example, to estimate the
+ "empirical Fisher" (instead of the true Fisher).
+ (Default: None)
+ name: (OPTIONAL) str or None. Unique name for this loss function. If None,
+ a new name is generated. (Default: None)
+ reuse: bool or str. If True, this adds `logits` as an additional
+ mini-batch/tower of inputs to the loss-function/predictive distribution
+ (which must have already been registered). If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
+ """
+ loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets,
+ seed=seed)
+ self.register_loss_function(loss, logits,
+ "categorical_predictive_distribution",
+ name=name, reuse=reuse)
+
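+  # A minimal usage sketch for register_categorical_predictive_distribution
+  # (illustrative; assumes `lc` is a LayerCollection, `logits` is the model's
+  # final pre-softmax output, and `labels` holds integer class indices;
+  # `targets` is only needed if total_loss() will be called):
+  #
+  #   lc.register_categorical_predictive_distribution(logits, targets=labels)
+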
+ def register_normal_predictive_distribution(self,
+ mean,
+ var=0.5,
+ seed=None,
+ targets=None,
+ name=None,
+ reuse=VARIABLE_SCOPE):
+ """Registers a normal predictive distribution.
+
+ Args:
+ mean: The mean vector defining the distribution.
+ var: The variance (must be a scalar). Note that the default value of
+ 0.5 corresponds to a standard squared error loss (target -
+ prediction)**2. If your squared error loss is of the form
+ 0.5*(target - prediction)**2 you should use var=1.0. (Default: 0.5)
+ seed: The seed for the RNG (for debugging) (Default: None)
+ targets: (OPTIONAL) The targets for the loss function. Only required if
+ one wants to call total_loss() instead of total_sampled_loss().
+ total_loss() is required, for example, to estimate the
+ "empirical Fisher" (instead of the true Fisher).
+ (Default: None)
+ name: (OPTIONAL) str or None. Unique name for this loss function. If None,
+ a new name is generated. (Default: None)
+ reuse: bool or str. If True, this adds `mean` and `var` as an additional
+ mini-batch/tower of inputs to the loss-function/predictive distribution
+ (which must have already been registered). If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
+ """
+ loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets,
+ seed=seed)
+ self.register_loss_function(loss, mean,
+ "normal_predictive_distribution",
+ name=name, reuse=reuse)
+
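+  # A minimal usage sketch for register_normal_predictive_distribution
+  # (illustrative; assumes `lc` is a LayerCollection and `prediction`,
+  # `target` were created elsewhere). The default var=0.5 corresponds to
+  # training with sum((target - prediction)**2); use var=1.0 for a
+  # 0.5*(target - prediction)**2 loss:
+  #
+  #   lc.register_normal_predictive_distribution(
+  #       prediction, var=0.5, targets=target)
+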
+ def register_multi_bernoulli_predictive_distribution(self,
+ logits,
+ seed=None,
+ targets=None,
+ name=None,
+ reuse=VARIABLE_SCOPE):
+ """Registers a multi-Bernoulli predictive distribution.
+
+ Args:
+ logits: The logits of the distribution (i.e. its parameters).
+ seed: The seed for the RNG (for debugging) (Default: None)
+ targets: (OPTIONAL) The targets for the loss function. Only required if
+ one wants to call total_loss() instead of total_sampled_loss().
+ total_loss() is required, for example, to estimate the
+ "empirical Fisher" (instead of the true Fisher).
+ (Default: None)
+ name: (OPTIONAL) str or None. Unique name for this loss function. If None,
+ a new name is generated. (Default: None)
+ reuse: bool or str. If True, this adds `logits` as an additional
+ mini-batch/tower of inputs to the loss-function/predictive distribution
+ (which must have already been registered). If "VARIABLE_SCOPE", use
+ tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
+ """
+ loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets,
+ seed=seed)
+ self.register_loss_function(loss, logits,
+ "multi_bernoulli_predictive_distribution",
+ name=name, reuse=reuse)
+
+ def make_or_get_factor(self, cls, args):
+ """Insert `cls(args)` into 'self.fisher_factors` if not already present.
+
+ Wraps constructor in `tf.variable_scope()` to ensure variables constructed
+ in `cls.__init__` are placed under this LayerCollection's scope.
+
+ Args:
+ cls: Class that implements FisherFactor.
+      args: Tuple of arguments to pass into `cls`'s constructor. Must be
+ hashable.
+
+ Returns:
+ Instance of `cls` found in self.fisher_factors.
+ """
+ try:
+ hash(args)
+ except TypeError:
+ raise TypeError(
+ ("Unable to use (cls, args) = ({}, {}) as a key in "
+ "LayerCollection.fisher_factors. The pair cannot be hashed.").format(
+ cls, args))
+
+ key = cls, args
+ if key not in self.fisher_factors:
+ with variable_scope.variable_scope(self._var_scope):
+ self.fisher_factors[key] = cls(*args)
+ return self.fisher_factors[key]
+
+ @contextmanager
+ def as_default(self):
+ """Sets this LayerCollection as the default."""
+ set_default_layer_collection(self)
+ yield
+ set_default_layer_collection(None)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
new file mode 100644
index 0000000000..9f46853807
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
@@ -0,0 +1,46 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Registry for layers and their parameters/variables.
+
+This represents the collection of all layers in the approximate Fisher
+information matrix to which a particular FisherBlock may belong. That is, we
+might have several layer collections for one TF graph (if we have multiple K-FAC
+optimizers being used, for example.)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.kfac.python.ops.layer_collection import *
+from tensorflow.python.util.all_util import remove_undocumented
+# pylint: enable=unused-import,line-too-long,wildcard-import
+
+_allowed_symbols = [
+ "get_default_layer_collection",
+ "set_default_layer_collection",
+ "LayerParametersDict",
+ "LayerCollection",
+ "APPROX_KRONECKER_NAME",
+ "APPROX_DIAGONAL_NAME",
+ "APPROX_FULL_NAME",
+ "VARIABLE_SCOPE",
+ "APPROX_KRONECKER_INDEP_NAME",
+ "APPROX_KRONECKER_SERIES_1_NAME",
+ "APPROX_KRONECKER_SERIES_2_NAME"
+]
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/linear_operator.py b/tensorflow/contrib/kfac/python/ops/linear_operator.py
new file mode 100644
index 0000000000..61cb955ae8
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/linear_operator.py
@@ -0,0 +1,95 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SmartMatrices definitions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.linalg import linalg
+from tensorflow.python.ops.linalg import linalg_impl
+from tensorflow.python.ops.linalg import linear_operator_util as lou
+
+
+class LinearOperatorExtras(object): # pylint: disable=missing-docstring
+
+ def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
+
+ with self._name_scope(name, values=[x]):
+ if isinstance(x, ops.IndexedSlices):
+ return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
+
+ x = ops.convert_to_tensor(x, name="x")
+ self._check_input_dtype(x)
+
+ self_dim = -2 if adjoint else -1
+ arg_dim = -1 if adjoint_arg else -2
+ self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
+
+ return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
+
+ def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
+
+ with self._name_scope(name, values=[x]):
+
+ if isinstance(x, ops.IndexedSlices):
+ return self._matmul_right_sparse(
+ x, adjoint=adjoint, adjoint_arg=adjoint_arg)
+
+ x = ops.convert_to_tensor(x, name="x")
+ self._check_input_dtype(x)
+
+ self_dim = -1 if adjoint else -2
+ arg_dim = -2 if adjoint_arg else -1
+ self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
+
+ return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
+
+
+class LinearOperatorFullMatrix(LinearOperatorExtras,
+ linalg.LinearOperatorFullMatrix):
+
+ # TODO(b/78117889) Remove this definition once core LinearOperator
+ # has _matmul_right.
+ def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
+ return lou.matmul_with_broadcast(
+ x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint)
+
+ def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
+ raise NotImplementedError
+
+ def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
+ assert not adjoint and not adjoint_arg
+ return utils.matmul_sparse_dense(x, self._matrix)
+
+
+class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring
+ linalg.LinearOperatorDiag):
+
+ def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
+ diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
+ x = linalg_impl.adjoint(x) if adjoint_arg else x
+ return diag_mat * x
+
+ def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
+ diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
+ assert not adjoint_arg
+ return utils.matmul_diag_sparse(diag_mat, x)
+
+ def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
+ raise NotImplementedError
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py
new file mode 100644
index 0000000000..c8cebc42cb
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py
@@ -0,0 +1,754 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Loss functions to be used by LayerCollection."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+import six
+
+from tensorflow.contrib.distributions.python.ops import onehot_categorical
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bernoulli
+from tensorflow.python.ops.distributions import categorical
+from tensorflow.python.ops.distributions import normal
+
+
+@six.add_metaclass(abc.ABCMeta)
+class LossFunction(object):
+ """Abstract base class for loss functions.
+
+ Note that unlike typical loss functions used in neural networks these are
+ summed and not averaged across cases in the batch, since this is what the
+ users of this class (FisherEstimator and MatrixVectorProductComputer) will
+  be expecting. The implication of this is that you may want to
+ normalize things like Fisher-vector products by the batch size when you
+ use this class. It depends on the use case.
+ """
+
+ @abc.abstractproperty
+ def targets(self):
+ """The targets being predicted by the model.
+
+ Returns:
+ None or Tensor of appropriate shape for calling self._evaluate() on.
+ """
+ pass
+
+ @abc.abstractproperty
+ def inputs(self):
+ """The inputs to the loss function (excluding the targets)."""
+ pass
+
+ def evaluate(self):
+ """Evaluate the loss function on the targets."""
+ if self.targets is not None:
+ # We treat the targets as "constant". It's only the inputs that get
+ # "back-propped" through.
+ return self._evaluate(array_ops.stop_gradient(self.targets))
+ else:
+ raise Exception("Cannot evaluate losses with unspecified targets.")
+
+ @abc.abstractmethod
+ def _evaluate(self, targets):
+ """Evaluates the negative log probability of the targets.
+
+ Args:
+ targets: Tensor that distribution can calculate log_prob() of.
+
+ Returns:
+ negative log probability of each target, summed across all targets.
+ """
+ pass
+
+ @abc.abstractmethod
+ def multiply_hessian(self, vector):
+ """Right-multiply a vector by the Hessian.
+
+ Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
+ of the loss function with respect to its inputs.
+
+ Args:
+ vector: The vector to multiply. Must be the same shape(s) as the
+ 'inputs' property.
+
+ Returns:
+ The vector right-multiplied by the Hessian. Will be of the same shape(s)
+ as the 'inputs' property.
+ """
+ pass
+
+ @abc.abstractmethod
+ def multiply_hessian_factor(self, vector):
+ """Right-multiply a vector by a factor B of the Hessian.
+
+ Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
+ of the loss function with respect to its inputs. Typically this will be
+ block-diagonal across different cases in the batch, since the loss function
+ is typically summed across cases.
+
+ Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
+ but will agree with the one used in the other methods of this class.
+
+ Args:
+ vector: The vector to multiply. Must be of the shape given by the
+ 'hessian_factor_inner_shape' property.
+
+ Returns:
+ The vector right-multiplied by B. Will be of the same shape(s) as the
+ 'inputs' property.
+ """
+ pass
+
+ @abc.abstractmethod
+ def multiply_hessian_factor_transpose(self, vector):
+ """Right-multiply a vector by the transpose of a factor B of the Hessian.
+
+ Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
+ of the loss function with respect to its inputs. Typically this will be
+ block-diagonal across different cases in the batch, since the loss function
+ is typically summed across cases.
+
+ Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
+ but will agree with the one used in the other methods of this class.
+
+ Args:
+ vector: The vector to multiply. Must be the same shape(s) as the
+ 'inputs' property.
+
+ Returns:
+ The vector right-multiplied by B^T. Will be of the shape given by the
+ 'hessian_factor_inner_shape' property.
+ """
+ pass
+
+ @abc.abstractmethod
+ def multiply_hessian_factor_replicated_one_hot(self, index):
+ """Right-multiply a replicated-one-hot vector by a factor B of the Hessian.
+
+ Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
+ of the loss function with respect to its inputs. Typically this will be
+ block-diagonal across different cases in the batch, since the loss function
+ is typically summed across cases.
+
+ A 'replicated-one-hot' vector means a tensor which, for each slice along the
+ batch dimension (assumed to be dimension 0), is 1.0 in the entry
+ corresponding to the given index and 0 elsewhere.
+
+ Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
+ but will agree with the one used in the other methods of this class.
+
+ Args:
+      index: A tuple representing the index of the entry in each slice that
+ is 1.0. Note that len(index) must be equal to the number of elements
+ of the 'hessian_factor_inner_shape' tensor minus one.
+
+ Returns:
+      The vector right-multiplied by B. Will be of the same shape(s) as the
+ 'inputs' property.
+ """
+ pass
+
+ @abc.abstractproperty
+ def hessian_factor_inner_shape(self):
+ """The shape of the tensor returned by multiply_hessian_factor."""
+ pass
+
+ @abc.abstractproperty
+ def hessian_factor_inner_static_shape(self):
+ """Static version of hessian_factor_inner_shape."""
+ pass
+
+
+@six.add_metaclass(abc.ABCMeta)
+class NegativeLogProbLoss(LossFunction):
+ """Abstract base class for loss functions that are negative log probs."""
+
+ def __init__(self, seed=None):
+ self._default_seed = seed
+ super(NegativeLogProbLoss, self).__init__()
+
+ @property
+ def inputs(self):
+ return self.params
+
+ @abc.abstractproperty
+ def params(self):
+ """Parameters to the underlying distribution."""
+ pass
+
+ @abc.abstractmethod
+ def multiply_fisher(self, vector):
+ """Right-multiply a vector by the Fisher.
+
+ Args:
+ vector: The vector to multiply. Must be the same shape(s) as the
+ 'inputs' property.
+
+ Returns:
+ The vector right-multiplied by the Fisher. Will be of the same shape(s)
+ as the 'inputs' property.
+ """
+ pass
+
+ @abc.abstractmethod
+ def multiply_fisher_factor(self, vector):
+ """Right-multiply a vector by a factor B of the Fisher.
+
+ Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
+ product of gradients) with respect to the parameters of the underlying
+ probability distribution (whose log-prob defines the loss). Typically this
+ will be block-diagonal across different cases in the batch, since the
+ distribution is usually (but not always) conditionally iid across different
+ cases.
+
+ Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
+ but will agree with the one used in the other methods of this class.
+
+ Args:
+ vector: The vector to multiply. Must be of the shape given by the
+ 'fisher_factor_inner_shape' property.
+
+ Returns:
+ The vector right-multiplied by B. Will be of the same shape(s) as the
+ 'inputs' property.
+ """
+ pass
+
+ @abc.abstractmethod
+ def multiply_fisher_factor_transpose(self, vector):
+ """Right-multiply a vector by the transpose of a factor B of the Fisher.
+
+ Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
+ product of gradients) with respect to the parameters of the underlying
+ probability distribution (whose log-prob defines the loss). Typically this
+ will be block-diagonal across different cases in the batch, since the
+ distribution is usually (but not always) conditionally iid across different
+ cases.
+
+ Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
+ but will agree with the one used in the other methods of this class.
+
+ Args:
+ vector: The vector to multiply. Must be the same shape(s) as the
+ 'inputs' property.
+
+ Returns:
+ The vector right-multiplied by B^T. Will be of the shape given by the
+ 'fisher_factor_inner_shape' property.
+ """
+ pass
+
+ @abc.abstractmethod
+ def multiply_fisher_factor_replicated_one_hot(self, index):
+ """Right-multiply a replicated-one-hot vector by a factor B of the Fisher.
+
+ Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
+ product of gradients) with respect to the parameters of the underlying
+ probability distribution (whose log-prob defines the loss). Typically this
+ will be block-diagonal across different cases in the batch, since the
+ distribution is usually (but not always) conditionally iid across different
+ cases.
+
+ A 'replicated-one-hot' vector means a tensor which, for each slice along the
+ batch dimension (assumed to be dimension 0), is 1.0 in the entry
+ corresponding to the given index and 0 elsewhere.
+
+    Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
+ but will agree with the one used in the other methods of this class.
+
+ Args:
+      index: A tuple representing the index of the entry in each slice that
+ is 1.0. Note that len(index) must be equal to the number of elements
+ of the 'fisher_factor_inner_shape' tensor minus one.
+
+ Returns:
+ The vector right-multiplied by B. Will be of the same shape(s) as the
+ 'inputs' property.
+ """
+ pass
+
+ @abc.abstractproperty
+ def fisher_factor_inner_shape(self):
+ """The shape of the tensor returned by multiply_fisher_factor."""
+ pass
+
+ @abc.abstractproperty
+ def fisher_factor_inner_static_shape(self):
+ """Static version of fisher_factor_inner_shape."""
+ pass
+
+ @abc.abstractmethod
+ def sample(self, seed):
+ """Sample 'targets' from the underlying distribution."""
+ pass
+
+ def evaluate_on_sample(self, seed=None):
+ """Evaluates the log probability on a random sample.
+
+ Args:
+ seed: int or None. Random seed for this draw from the distribution.
+
+ Returns:
+ Log probability of sampled targets, summed across examples.
+ """
+ if seed is None:
+ seed = self._default_seed
+ # We treat the targets as "constant". It's only the inputs that get
+ # "back-propped" through.
+ return self._evaluate(array_ops.stop_gradient(self.sample(seed)))
+
+
+# TODO(jamesmartens): should this just inherit from object to avoid "diamond"
+# inheritance, or is there a better way?
+class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss):
+ """Base class for neg log prob losses whose inputs are 'natural' parameters.
+
+ Note that the Hessian and Fisher for natural parameters of exponential-
+ family models are the same, hence the purpose of this class.
+ See here: https://arxiv.org/abs/1412.1193
+
+ 'Natural parameters' are defined for exponential-family models. See for
+ example: https://en.wikipedia.org/wiki/Exponential_family
+ """
+
+ def multiply_hessian(self, vector):
+ return self.multiply_fisher(vector)
+
+ def multiply_hessian_factor(self, vector):
+ return self.multiply_fisher_factor(vector)
+
+ def multiply_hessian_factor_transpose(self, vector):
+ return self.multiply_fisher_factor_transpose(vector)
+
+ def multiply_hessian_factor_replicated_one_hot(self, index):
+ return self.multiply_fisher_factor_replicated_one_hot(index)
+
+ @property
+ def hessian_factor_inner_shape(self):
+ return self.fisher_factor_inner_shape
+
+ @property
+ def hessian_factor_inner_static_shape(self):
+    return self.fisher_factor_inner_static_shape
+
+
+class DistributionNegativeLogProbLoss(NegativeLogProbLoss):
+ """Base class for neg log prob losses that use the TF Distribution classes."""
+
+ def __init__(self, seed=None):
+ super(DistributionNegativeLogProbLoss, self).__init__(seed=seed)
+
+ @abc.abstractproperty
+ def dist(self):
+ """The underlying tf.distributions.Distribution."""
+ pass
+
+ def _evaluate(self, targets):
+ return -math_ops.reduce_sum(self.dist.log_prob(targets))
+
+ def sample(self, seed):
+ return self.dist.sample(seed=seed)
+
+
+class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
+ NaturalParamsNegativeLogProbLoss):
+ """Neg log prob loss for a normal distribution parameterized by a mean vector.
+
+ Note that the covariance is treated as a constant 'var' times the identity.
+  Also note that the Fisher for such a normal distribution, with respect to
+  the mean parameter, is given by:
+
+ F = (1/var) * I
+
+ See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.
+ """
+
+ def __init__(self, mean, var=0.5, targets=None, seed=None):
+ self._mean = mean
+ self._var = var
+ self._targets = targets
+ super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed)
+
+ @property
+ def targets(self):
+ return self._targets
+
+ @property
+ def dist(self):
+ return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var))
+
+ @property
+ def params(self):
+ return self._mean
+
+ def multiply_fisher(self, vector):
+ return (1. / self._var) * vector
+
+ def multiply_fisher_factor(self, vector):
+ return self._var**-0.5 * vector
+
+ def multiply_fisher_factor_transpose(self, vector):
+ return self.multiply_fisher_factor(vector) # it's symmetric in this case
+
+ def multiply_fisher_factor_replicated_one_hot(self, index):
+ assert len(index) == 1, "Length of index was {}".format(len(index))
+ ones_slice = array_ops.expand_dims(
+ array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype),
+ axis=-1)
+ output_slice = self._var**-0.5 * ones_slice
+ return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]),
+ index[0])
+
+ @property
+ def fisher_factor_inner_shape(self):
+ return array_ops.shape(self._mean)
+
+ @property
+ def fisher_factor_inner_static_shape(self):
+ return self._mean.shape
+
+
+class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
+ """Negative log prob loss for a normal distribution with mean and variance.
+
+ This class parameterizes a multivariate normal distribution with n independent
+ dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not
+ assume the variance is held constant. The Fisher Information for n = 1
+ is given by,
+
+ F = [[1 / variance, 0],
+ [ 0, 0.5 / variance^2]]
+
+ where the parameters of the distribution are concatenated into a single
+ vector as [mean, variance]. For n > 1, the mean parameter vector is
+ concatenated with the variance parameter vector.
+
+ See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation.
+ """
+
+ def __init__(self, mean, variance, targets=None, seed=None):
+ assert len(mean.shape) == 2, "Expect 2D mean tensor."
+ assert len(variance.shape) == 2, "Expect 2D variance tensor."
+ self._mean = mean
+ self._variance = variance
+ self._targets = targets
+ super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed)
+
+ @property
+ def targets(self):
+ return self._targets
+
+ @property
+ def dist(self):
+ return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance))
+
+ @property
+ def params(self):
+ return self._mean, self._variance
+
+ def _concat(self, mean, variance):
+ return array_ops.concat([mean, variance], axis=-1)
+
+ def _split(self, params):
+ return array_ops.split(params, 2, axis=-1)
+
+ @property
+ def _fisher_mean(self):
+ return 1. / self._variance
+
+ @property
+ def _fisher_mean_factor(self):
+ return 1. / math_ops.sqrt(self._variance)
+
+ @property
+ def _fisher_var(self):
+ return 1. / (2 * math_ops.square(self._variance))
+
+ @property
+ def _fisher_var_factor(self):
+ return 1. / (math_ops.sqrt(2.) * self._variance)
+
+ def multiply_fisher(self, vecs):
+ mean_vec, var_vec = vecs
+ return (self._fisher_mean * mean_vec, self._fisher_var * var_vec)
+
+ def multiply_fisher_factor(self, vecs):
+ mean_vec, var_vec = self._split(vecs)
+ return (self._fisher_mean_factor * mean_vec,
+ self._fisher_var_factor * var_vec)
+
+ def multiply_fisher_factor_transpose(self, vecs):
+ mean_vec, var_vec = vecs
+ return self._concat(self._fisher_mean_factor * mean_vec,
+ self._fisher_var_factor * var_vec)
+
+ def multiply_fisher_factor_replicated_one_hot(self, index):
+ assert len(index) == 1, "Length of index was {}".format(len(index))
+ index = index[0]
+
+ if index < int(self._mean.shape[-1]):
+ # Index corresponds to mean parameter.
+ mean_slice = self._fisher_mean_factor[:, index]
+ mean_slice = array_ops.expand_dims(mean_slice, axis=-1)
+ mean_output = insert_slice_in_zeros(mean_slice, 1, int(
+ self._mean.shape[1]), index)
+ var_output = array_ops.zeros_like(mean_output)
+ else:
+ index -= int(self._mean.shape[-1])
+ # Index corresponds to variance parameter.
+ var_slice = self._fisher_var_factor[:, index]
+ var_slice = array_ops.expand_dims(var_slice, axis=-1)
+ var_output = insert_slice_in_zeros(var_slice, 1,
+ int(self._variance.shape[1]), index)
+ mean_output = array_ops.zeros_like(var_output)
+
+ return mean_output, var_output
+
+ @property
+ def fisher_factor_inner_shape(self):
+ return array_ops.concat(
+ [
+ array_ops.shape(self._mean)[:-1],
+ 2 * array_ops.shape(self._mean)[-1:]
+ ],
+ axis=0)
+
+ @property
+ def fisher_factor_inner_static_shape(self):
+ shape = self._mean.shape.as_list()
+    return tensor_shape.TensorShape(shape[:-1] + [2 * shape[-1]])
+
+ def multiply_hessian(self, vector):
+ raise NotImplementedError()
+
+ def multiply_hessian_factor(self, vector):
+ raise NotImplementedError()
+
+ def multiply_hessian_factor_transpose(self, vector):
+ raise NotImplementedError()
+
+ def multiply_hessian_factor_replicated_one_hot(self, index):
+ raise NotImplementedError()
+
+ @property
+ def hessian_factor_inner_shape(self):
+ raise NotImplementedError()
+
+ @property
+ def hessian_factor_inner_static_shape(self):
+ raise NotImplementedError()
+
+
+class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
+ NaturalParamsNegativeLogProbLoss):
+ """Neg log prob loss for a categorical distribution parameterized by logits.
+
+ Note that the Fisher (for a single case) of a categorical distribution, with
+ respect to the natural parameters (i.e. the logits), is given by:
+
+ F = diag(p) - p*p^T
+
+ where p = softmax(logits). F can be factorized as F = B * B^T where
+
+ B = diag(q) - p*q^T
+
+ where q is the entry-wise square root of p. This is easy to verify using the
+ fact that q^T*q = 1.
+ """
+
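+  # A quick numerical check of the factorization above (a numpy sketch, not
+  # part of this class; assumes `np` is numpy):
+  #
+  #   p = np.array([0.2, 0.3, 0.5]); q = np.sqrt(p)
+  #   F = np.diag(p) - np.outer(p, p)
+  #   B = np.diag(q) - np.outer(p, q)
+  #   np.allclose(B.dot(B.T), F)  # True
+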
+ def __init__(self, logits, targets=None, seed=None):
+ """Instantiates a CategoricalLogitsNegativeLogProbLoss.
+
+ Args:
+ logits: Tensor of shape [batch_size, output_size]. Parameters for
+ underlying distribution.
+      targets: None or Tensor of shape [batch_size]. Each element contains an
+ index in [0, output_size).
+ seed: int or None. Default random seed when sampling.
+ """
+ self._logits = logits
+ self._targets = targets
+ super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed)
+
+ @property
+ def targets(self):
+ return self._targets
+
+ @property
+ def dist(self):
+ return categorical.Categorical(logits=self._logits)
+
+ @property
+ def _probs(self):
+ return self.dist.probs
+
+ @property
+ def _sqrt_probs(self):
+ return math_ops.sqrt(self._probs)
+
+ @property
+ def params(self):
+ return self._logits
+
+ def multiply_fisher(self, vector):
+ probs = self._probs
+ return vector * probs - probs * math_ops.reduce_sum(
+ vector * probs, axis=-1, keepdims=True)
+
+ def multiply_fisher_factor(self, vector):
+ probs = self._probs
+ sqrt_probs = self._sqrt_probs
+ return sqrt_probs * vector - probs * math_ops.reduce_sum(
+ sqrt_probs * vector, axis=-1, keepdims=True)
+
+ def multiply_fisher_factor_transpose(self, vector):
+ probs = self._probs
+ sqrt_probs = self._sqrt_probs
+ return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum(
+ probs * vector, axis=-1, keepdims=True)
+
+ def multiply_fisher_factor_replicated_one_hot(self, index):
+ assert len(index) == 1, "Length of index was {}".format(len(index))
+ probs = self._probs
+ sqrt_probs = self._sqrt_probs
+ sqrt_probs_slice = array_ops.expand_dims(sqrt_probs[:, index[0]], -1)
+ padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1,
+ int(sqrt_probs.shape[1]), index[0])
+ return padded_slice - probs * sqrt_probs_slice
+
+ @property
+ def fisher_factor_inner_shape(self):
+ return array_ops.shape(self._logits)
+
+ @property
+ def fisher_factor_inner_static_shape(self):
+ return self._logits.shape
+
+
+class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss,
+ NaturalParamsNegativeLogProbLoss):
+ """Neg log prob loss for multiple Bernoulli distributions param'd by logits.
+
+ Represents N independent Bernoulli distributions where N = len(logits). Its
+ Fisher Information matrix is given by,
+
+ F = diag(p * (1-p))
+ p = sigmoid(logits)
+
+ As F is diagonal with positive entries, its factor B is,
+
+ B = diag(sqrt(p * (1-p)))
+ """
+
+ def __init__(self, logits, targets=None, seed=None):
+ self._logits = logits
+ self._targets = targets
+ super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed)
+
+ @property
+ def targets(self):
+ return self._targets
+
+ @property
+ def dist(self):
+ return bernoulli.Bernoulli(logits=self._logits)
+
+ @property
+ def _probs(self):
+ return self.dist.probs
+
+ @property
+ def params(self):
+ return self._logits
+
+ def multiply_fisher(self, vector):
+ return self._probs * (1 - self._probs) * vector
+
+ def multiply_fisher_factor(self, vector):
+ return math_ops.sqrt(self._probs * (1 - self._probs)) * vector
+
+ def multiply_fisher_factor_transpose(self, vector):
+ return self.multiply_fisher_factor(vector) # it's symmetric in this case
+
+ def multiply_fisher_factor_replicated_one_hot(self, index):
+ assert len(index) == 1, "Length of index was {}".format(len(index))
+ probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1)
+ output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice))
+ return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]),
+ index[0])
+
+ @property
+ def fisher_factor_inner_shape(self):
+ return array_ops.shape(self._logits)
+
+ @property
+ def fisher_factor_inner_static_shape(self):
+ return self._logits.shape
+
+
+def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position):
+ """Inserts slice into a larger tensor of zeros.
+
+ Forms a new tensor which is the same shape as slice_to_insert, except that
+ the dimension given by 'dim' is expanded to the size given by 'dim_size'.
+ 'position' determines the position (index) at which to insert the slice within
+ that dimension.
+
+ Assumes slice_to_insert.shape[dim] = 1.
+
+ Args:
+ slice_to_insert: The slice to insert.
+    dim: The dimension to expand with zeros.
+ dim_size: The new size of the 'dim' dimension.
+ position: The position of 'slice_to_insert' in the new tensor.
+
+ Returns:
+ The new tensor.
+
+ Raises:
+ ValueError: If the slice's shape at the given dim is not 1.
+ """
+ slice_shape = slice_to_insert.shape
+ if slice_shape[dim] != 1:
+    raise ValueError("Expected slice_to_insert.shape to have size 1 in dim {},"
+                     " but was {}".format(dim, slice_to_insert.shape[dim]))
+
+ before = [0] * int(len(slice_shape))
+ after = before[:]
+ before[dim] = position
+ after[dim] = dim_size - position - 1
+
+ return array_ops.pad(slice_to_insert, list(zip(before, after)))
+
+
+class OnehotCategoricalLogitsNegativeLogProbLoss(
+ CategoricalLogitsNegativeLogProbLoss):
+ """Neg log prob loss for a categorical distribution with onehot targets.
+
+ Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying
+ distribution is OneHotCategorical as opposed to Categorical.
+ """
+
+ @property
+ def dist(self):
+ return onehot_categorical.OneHotCategorical(logits=self._logits)
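As a quick sanity check of the factorization described in the class docstring above, the following standalone NumPy sketch (illustrative only, not part of this change) verifies that B = diag(q) - p*q^T with q = sqrt(p) is indeed a factor of the categorical Fisher F = diag(p) - p*p^T for an arbitrary softmax distribution:

    # Standalone NumPy check (illustrative only, not part of this change) that
    # B = diag(q) - p q^T, with q = sqrt(p), satisfies B B^T = diag(p) - p p^T.
    # The identity relies on q^T q = sum(p) = 1.
    import numpy as np

    logits = np.array([0.5, -1.0, 2.0])
    p = np.exp(logits) / np.sum(np.exp(logits))  # softmax probabilities
    q = np.sqrt(p)

    F = np.diag(p) - np.outer(p, p)  # Fisher of the categorical w.r.t. logits
    B = np.diag(q) - np.outer(p, q)  # proposed factor
    assert np.allclose(np.dot(B, B.T), F)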
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py
new file mode 100644
index 0000000000..4279cb2792
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py
@@ -0,0 +1,39 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Loss functions to be used by LayerCollection."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.kfac.python.ops.loss_functions import *
+from tensorflow.python.util.all_util import remove_undocumented
+# pylint: enable=unused-import,line-too-long,wildcard-import
+
+_allowed_symbols = [
+ "LossFunction",
+ "NegativeLogProbLoss",
+ "NaturalParamsNegativeLogProbLoss",
+ "DistributionNegativeLogProbLoss",
+ "NormalMeanNegativeLogProbLoss",
+ "NormalMeanVarianceNegativeLogProbLoss",
+ "CategoricalLogitsNegativeLogProbLoss",
+ "OnehotCategoricalLogitsNegativeLogProbLoss",
+ "MultiBernoulliNegativeLogProbLoss",
+ "insert_slice_in_zeros",
+]
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/op_queue.py b/tensorflow/contrib/kfac/python/ops/op_queue.py
new file mode 100644
index 0000000000..b6d9d37a31
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/op_queue.py
@@ -0,0 +1,69 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helper for choosing which op to run next in a distributed setting."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops as tf_ops
+
+
+class OpQueue(object):
+ """Class for choosing which Op to run next.
+
+ Constructs an infinitely repeating sequence of Ops in shuffled order.
+
+ In K-FAC, this can be used to distribute inverse update operations among
+ workers.
+ """
+
+ def __init__(self, ops, seed=None):
+ """Initializes an OpQueue.
+
+ Args:
+ ops: list of TensorFlow Ops. Ops to be selected from. All workers must
+ initialize with the same set of ops.
+ seed: int or None. Random seed used when shuffling order of ops.
+ """
+ self._ops_by_name = {op.name: op for op in ops}
+
+ # Construct a (shuffled) Dataset with Op names.
+ op_names = tf_ops.convert_to_tensor(list(sorted(op.name for op in ops)))
+ op_names_dataset = (dataset_ops.Dataset.from_tensor_slices(op_names)
+ .shuffle(len(ops), seed=seed).repeat())
+ self._next_op_name = op_names_dataset.make_one_shot_iterator().get_next()
+
+ @property
+ def ops(self):
+ """Ops this OpQueue can return in next_op()."""
+ return self._ops_by_name.values()
+
+ def next_op(self, sess):
+ """Chooses which op to run next.
+
+ Note: This call will make a call to sess.run().
+
+ Args:
+ sess: tf.Session.
+
+ Returns:
+ Next Op chosen from 'ops'.
+ """
+ # In Python 3, type(next_op_name) == bytes. Calling bytes.decode('ascii')
+ # returns a str.
+ next_op_name = sess.run(self._next_op_name).decode('ascii')
+ return self._ops_by_name[next_op_name]
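A minimal usage sketch for OpQueue (illustrative only, not part of this change): every worker constructs the queue over the same list of inverse-update ops with the same seed, then repeatedly runs whichever op the queue selects. The names `inv_update_ops`, `sess`, and `num_steps` below are assumed to already exist.

    # Illustrative sketch only; assumes `inv_update_ops` (the same list of ops
    # on every worker), a tf.Session `sess`, and an int `num_steps`.
    from tensorflow.contrib.kfac.python.ops import op_queue

    queue = op_queue.OpQueue(inv_update_ops, seed=0)  # same ops/seed everywhere
    for _ in range(num_steps):
        op = queue.next_op(sess)  # makes an internal sess.run() call
        sess.run(op)              # run the selected inverse-update op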
diff --git a/tensorflow/contrib/kfac/python/ops/op_queue_lib.py b/tensorflow/contrib/kfac/python/ops/op_queue_lib.py
new file mode 100644
index 0000000000..09c9a4ab33
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/op_queue_lib.py
@@ -0,0 +1,30 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helper for choosing which op to run next in a distributed setting."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.kfac.python.ops.op_queue import *
+from tensorflow.python.util.all_util import remove_undocumented
+# pylint: enable=unused-import,line-too-long,wildcard-import
+
+_allowed_symbols = [
+ 'OpQueue',
+]
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py
new file mode 100644
index 0000000000..38605259b5
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/optimizer.py
@@ -0,0 +1,727 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The KFAC optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+# pylint: disable=line-too-long
+from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
+from tensorflow.contrib.kfac.python.ops import estimator as est
+# pylint: enable=line-too-long
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.training import gradient_descent
+
+
+class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
+ """The KFAC Optimizer (https://arxiv.org/abs/1503.05671)."""
+
+ def __init__(self,
+ learning_rate,
+ cov_ema_decay,
+ damping,
+ layer_collection,
+ var_list=None,
+ momentum=0.9,
+ momentum_type="regular",
+ norm_constraint=None,
+ name="KFAC",
+ estimation_mode="gradients",
+ colocate_gradients_with_ops=True,
+ batch_size=None,
+ placement_strategy=None,
+ **kwargs):
+ """Initializes the KFAC optimizer with the given settings.
+
+ Args:
+ learning_rate: The base learning rate for the optimizer. Should probably
+        be set to 1.0 when using momentum_type = 'qmodel', but can still be
+        lowered if desired (effectively lowering the trust in the
+        quadratic model).
+ cov_ema_decay: The decay factor used when calculating the covariance
+ estimate moving averages.
+ damping: The damping factor used to stabilize training due to errors in
+ the local approximation with the Fisher information matrix, and to
+ regularize the update direction by making it closer to the gradient.
+ If damping is adapted during training then this value is used for
+        initializing the damping variable.
+ (Higher damping means the update looks more like a standard gradient
+ update - see Tikhonov regularization.)
+      layer_collection: The layer collection object, which holds the Fisher
+ blocks, Kronecker factors, and losses associated with the
+ graph. The layer_collection cannot be modified after KfacOptimizer's
+ initialization.
+ var_list: Optional list or tuple of variables to train. Defaults to the
+ list of variables collected in the graph under the key
+ `GraphKeys.TRAINABLE_VARIABLES`.
+ momentum: The momentum decay constant to use. Only applies when
+ momentum_type is 'regular' or 'adam'. (Default: 0.9)
+ momentum_type: The type of momentum to use in this optimizer, one of
+ 'regular', 'adam', or 'qmodel'. (Default: 'regular')
+ norm_constraint: float or Tensor. If specified, the update is scaled down
+ so that its approximate squared Fisher norm v^T F v is at most the
+ specified value. May only be used with momentum type 'regular'.
+ (Default: None)
+ name: The name for this optimizer. (Default: 'KFAC')
+ estimation_mode: The type of estimator to use for the Fishers. Can be
+ 'gradients', 'empirical', 'curvature_propagation', or 'exact'.
+        (Default: 'gradients'). See the doc-string for FisherEstimator for
+        a more detailed description of these options.
+      colocate_gradients_with_ops: Whether we should request that gradients
+        computed in the estimator be colocated with their respective ops.
+ (Default: True)
+ batch_size: The size of the mini-batch. Only needed when momentum_type
+ == 'qmodel' or when automatic adjustment is used. (Default: None)
+ placement_strategy: string, Device placement strategy used when creating
+ covariance variables, covariance ops, and inverse ops.
+ (Default: `None`)
+ **kwargs: Arguments to be passed to specific placement
+ strategy mixin. Check `placement.RoundRobinPlacementMixin` for example.
+
+ Raises:
+ ValueError: If the momentum type is unsupported.
+ ValueError: If clipping is used with momentum type other than 'regular'.
+ ValueError: If no losses have been registered with layer_collection.
+ ValueError: If momentum is non-zero and momentum_type is not 'regular'
+ or 'adam'.
+ """
+ warnings.warn(
+ "third_party.tensorflow.contrib.kfac is deprecated."
+ "This will be removed on 15-07-2018. Check README for further details.",
+ DeprecationWarning)
+ # Parameters to be passed to the Fisher estimator:
+ self._variables = var_list or tf_variables.trainable_variables
+ self._cov_ema_decay = cov_ema_decay
+ self._layers = layer_collection
+ self._estimation_mode = estimation_mode
+ self._colocate_gradients_with_ops = colocate_gradients_with_ops
+
+ # The below parameters are required only if damping needs to be adapted.
+ # These parameters can be set by calling
+ # set_damping_adaptation_params() explicitly.
+ self._damping_adaptation_decay = 0.95
+ self._damping_adaptation_interval = 5
+    # See KFAC paper, Section 6.5: omega(1) = pow(damping decay, interval)
+ self._omega = (
+ self._damping_adaptation_decay**self._damping_adaptation_interval)
+ self._adapt_damping = False
+ self._min_damping = 1e-5
+ self._prev_train_batch = None
+ self._is_chief = False
+ self._loss_fn = None
+ self._damping_constant = damping
+ self._damping = None
+ self._rho = None
+ self._prev_loss = None
+ self._q_model_change = None
+ self._update_damping_op = None
+
+ momentum_type = momentum_type.lower()
+ legal_momentum_types = ["regular", "adam", "qmodel"]
+
+ if momentum_type not in legal_momentum_types:
+ raise ValueError("Unsupported momentum type {}. Must be one of {}."
+ .format(momentum_type, legal_momentum_types))
+ if momentum_type != "regular" and norm_constraint is not None:
+ raise ValueError("Update clipping is only supported with momentum "
+ "type 'regular'.")
+ if momentum_type not in ["regular", "adam"] and momentum != 0:
+ raise ValueError("Momentum must be unspecified if using a momentum_type "
+ "other than 'regular' or 'adam'.")
+
+ # Extra parameters of the optimizer
+ self._momentum = momentum
+ self._momentum_type = momentum_type
+ self._norm_constraint = norm_constraint
+ self._batch_size = batch_size
+ self._placement_strategy = placement_strategy
+
+ with variable_scope.variable_scope(name):
+ self._fisher_est = est.make_fisher_estimator(
+ placement_strategy=placement_strategy,
+ variables=self._variables,
+ cov_ema_decay=self._cov_ema_decay,
+ damping=self.damping,
+ layer_collection=self._layers,
+ exps=(-1,),
+ estimation_mode=self._estimation_mode,
+ colocate_gradients_with_ops=self._colocate_gradients_with_ops,
+ **kwargs)
+
+ super(KfacOptimizer, self).__init__(learning_rate, name=name)
+
+ def set_damping_adaptation_params(self,
+ is_chief,
+ prev_train_batch,
+ loss_fn,
+ min_damping=1e-5,
+ damping_adaptation_decay=0.99,
+ damping_adaptation_interval=5):
+ """Sets parameters required to adapt damping during training.
+
+ When called, enables damping adaptation according to the Levenberg-Marquardt
+ style rule described in Section 6.5 of "Optimizing Neural Networks with
+ Kronecker-factored Approximate Curvature".
+
+    Note that this function creates TensorFlow variables which store a few
+ scalars and are accessed by the ops which update the damping (as part
+ of the training op returned by the minimize() method).
+
+ Args:
+ is_chief: `Boolean`, `True` if the worker is chief.
+ prev_train_batch: Training data used to minimize loss in the previous
+ step. This will be used to evaluate loss by calling
+ `loss_fn(prev_train_batch)`.
+ loss_fn: `function` that takes as input training data tensor and returns
+ a scalar loss.
+ min_damping: `float`(Optional), Minimum value the damping parameter
+ can take. Default value 1e-5.
+ damping_adaptation_decay: `float`(Optional), The `damping` parameter is
+ multiplied by the `damping_adaptation_decay` every
+ `damping_adaptation_interval` number of iterations. Default value 0.99.
+ damping_adaptation_interval: `int`(Optional), Number of steps in between
+ updating the `damping` parameter. Default value 5.
+
+ Raises:
+      ValueError: If `set_damping_adaptation_params` has already been called
+        and `adapt_damping` is `True`.
+ """
+ if self._adapt_damping:
+ raise ValueError("Damping adaptation parameters already set.")
+
+ with variable_scope.variable_scope(self.get_name()):
+ self._adapt_damping = True
+ self._is_chief = is_chief
+ self._prev_train_batch = prev_train_batch
+ self._loss_fn = loss_fn
+ self._damping_adaptation_decay = damping_adaptation_decay
+ self._damping_adaptation_interval = damping_adaptation_interval
+ self._omega = (
+ self._damping_adaptation_decay**self._damping_adaptation_interval)
+ self._min_damping = min_damping
+
+ self._rho = variable_scope.get_variable(
+ "rho", shape=(), dtype=dtypes.float32, trainable=False) # LM ratio.
+ self._prev_loss = variable_scope.get_variable(
+ "prev_loss", shape=(), dtype=dtypes.float32, trainable=False)
+ self._q_model_change = variable_scope.get_variable(
+ "q_model_change", shape=(), dtype=dtypes.float32, trainable=False)
+ self._damping = variable_scope.get_variable(
+ "damping", initializer=self._damping_constant, trainable=False)
+
+ @property
+ def variables(self):
+ return self._fisher_est.variables
+
+ @property
+ def damping(self):
+ if self._damping:
+ return self._damping
+ else:
+ return self._damping_constant
+
+ @property
+ def damping_adaptation_interval(self):
+ return self._damping_adaptation_interval
+
+ def make_vars_and_create_op_thunks(self):
+ """Make vars and create op thunks.
+
+ Returns:
+ cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
+ the list of factors given by the "factors" property.
+ inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
+ the list of factors given by the "factors" property.
+ """
+ scope = self.get_name() + "/" + self._fisher_est.name
+ return self._fisher_est.make_vars_and_create_op_thunks(scope=scope)
+
+ def create_ops_and_vars_thunks(self):
+ """Create thunks that make the ops and vars on demand.
+
+ This function returns 4 lists of thunks: cov_variable_thunks,
+ cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
+
+ The length of each list is the number of factors and the i-th element of
+ each list corresponds to the i-th factor (given by the "factors" property).
+
+ Note that the execution of these thunks must happen in a certain
+ partial order. The i-th element of cov_variable_thunks must execute
+ before the i-th element of cov_update_thunks (and also the i-th element
+ of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
+ must execute before the i-th element of inv_update_thunks.
+
+ TL;DR (oversimplified): Execute the thunks according to the order that
+ they are returned.
+
+ Returns:
+ cov_variable_thunks: A list of thunks that make the cov variables.
+ cov_update_thunks: A list of thunks that make the cov update ops.
+ inv_variable_thunks: A list of thunks that make the inv variables.
+ inv_update_thunks: A list of thunks that make the inv update ops.
+ """
+ scope = self.get_name() + "/" + self._fisher_est.name
+ return self._fisher_est.create_ops_and_vars_thunks(scope=scope)
+
+ def minimize(self, *args, **kwargs):
+ # Should this variable scope encompass everything below? Or will the super-
+ # class make another copy of the same name scope?
+ with variable_scope.variable_scope(self.get_name()):
+ kwargs["var_list"] = kwargs.get("var_list") or self.variables
+ if set(kwargs["var_list"]) != set(self.variables):
+ raise ValueError("var_list doesn't match with set of Fisher-estimating "
+ "variables.")
+ if self._adapt_damping and self._is_chief:
+ global_step = kwargs.get("global_step", None)
+ if not global_step:
+ raise KeyError("global_step needs to be passed to optimizer.minimize "
+ "if damping parameter is adapted.")
+ update_damping_op = self._update_damping(self._prev_train_batch,
+ global_step)
+ with ops.control_dependencies([update_damping_op]):
+ loss = args[0]
+ loss_assign_op = state_ops.assign(self._prev_loss, loss)
+ train_op = super(KfacOptimizer, self).minimize(*args, **kwargs)
+ return control_flow_ops.group(loss_assign_op, train_op)
+ else:
+ return super(KfacOptimizer, self).minimize(*args, **kwargs)
+
+ def compute_gradients(self, *args, **kwargs):
+ # args[1] could be our var_list
+ if len(args) > 1:
+ var_list = args[1]
+ else:
+ kwargs["var_list"] = kwargs.get("var_list") or self.variables
+ var_list = kwargs["var_list"]
+
+ if set(var_list) != set(self.variables):
+ raise ValueError("var_list doesn't match with set of Fisher-estimating "
+ "variables.")
+ return super(KfacOptimizer, self).compute_gradients(*args, **kwargs)
+
+ def apply_gradients(self, grads_and_vars, *args, **kwargs):
+ """Applies gradients to variables.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs.
+ *args: Additional arguments for super.apply_gradients.
+ **kwargs: Additional keyword arguments for super.apply_gradients.
+
+ Returns:
+ An `Operation` that applies the specified gradients.
+ """
+ # In Python 3, grads_and_vars can be a zip() object which can only be
+ # iterated over once. By converting it to a list, we ensure that it can be
+ # iterated over more than once.
+ grads_and_vars = list(grads_and_vars)
+
+ # Compute step.
+ steps_and_vars = self._compute_update_steps(grads_and_vars)
+
+ # Update trainable variables with this step.
+ return super(KfacOptimizer, self).apply_gradients(steps_and_vars, *args,
+ **kwargs)
+
+ def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars):
+ """Computes the squared (approximate) Fisher norm of the updates.
+
+ This is defined as v^T F v, where F is the approximate Fisher matrix
+ as computed by the estimator, and v = F^{-1} g, where g is the gradient.
+ This is computed efficiently as v^T g.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs.
+ precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
+ Must be the result of calling `self._fisher_est.multiply_inverse`
+ on `grads_and_vars`.
+
+ Returns:
+ Scalar representing the squared norm.
+
+ Raises:
+ ValueError: if the two list arguments do not contain the same variables,
+ in the same order.
+ """
+ for (_, gvar), (_, pgvar) in zip(grads_and_vars, precon_grads_and_vars):
+ if gvar is not pgvar:
+ raise ValueError("The variables referenced by the two arguments "
+ "must match.")
+ terms = [
+ math_ops.reduce_sum(grad * pgrad)
+ for (grad, _), (pgrad, _) in zip(grads_and_vars, precon_grads_and_vars)
+ ]
+ return math_ops.reduce_sum(terms)
+
+ def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars):
+ """Computes the scale factor for the update to satisfy the norm constraint.
+
+ Defined as min(1, sqrt(c / r^T F r)), where c is the norm constraint,
+ F is the approximate Fisher matrix, and r is the update vector, i.e.
+ -alpha * v, where alpha is the learning rate, and v is the preconditioned
+ gradient.
+
+ This is based on Section 5 of Ba et al., Distributed Second-Order
+ Optimization using Kronecker-Factored Approximations. Note that they
+ absorb the learning rate alpha (which they denote eta_max) into the formula
+ for the coefficient, while in our implementation, the rescaling is done
+ before multiplying by alpha. Hence, our formula differs from theirs by a
+ factor of alpha.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs.
+ precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
+ Must be the result of calling `self._fisher_est.multiply_inverse`
+ on `grads_and_vars`.
+
+ Returns:
+ Scalar representing the coefficient which should be applied to the
+ preconditioned gradients to satisfy the norm constraint.
+ """
+ sq_norm_grad = self._squared_fisher_norm(grads_and_vars,
+ precon_grads_and_vars)
+ sq_norm_up = sq_norm_grad * self._learning_rate**2
+ return math_ops.minimum(1.,
+ math_ops.sqrt(self._norm_constraint / sq_norm_up))
+
+ def _clip_updates(self, grads_and_vars, precon_grads_and_vars):
+ """Rescales the preconditioned gradients to satisfy the norm constraint.
+
+ Rescales the preconditioned gradients such that the resulting update r
+ (after multiplying by the learning rate) will satisfy the norm constraint.
+ This constraint is that r^T F r <= C, where F is the approximate Fisher
+ matrix, and C is the norm_constraint attribute. See Section 5 of
+ Ba et al., Distributed Second-Order Optimization using Kronecker-Factored
+ Approximations.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs.
+ precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
+ Must be the result of calling `self._fisher_est.multiply_inverse`
+ on `grads_and_vars`.
+
+ Returns:
+ List of (rescaled preconditioned gradient, variable) pairs.
+ """
+ coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars)
+ return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars]
+
+ def _compute_prev_updates(self, variables):
+ """Computes previous updates as negative velocities scaled by learning rate.
+
+ Args:
+ variables: List of variables in the graph that the update will be
+ applied to.
+
+ Returns:
+ List of previous updates applied to the `variables`.
+ """
+ return list(
+ -1 * self._learning_rate * self._zeros_slot(var, "velocity", self._name)
+ for var in variables)
+
+ def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads,
+ variables):
+ """Compute optimal update hyperparameters from the quadratic model.
+
+    More specifically, if L is the loss, we minimize a quadratic approximation
+    of L(theta + d), which we denote by qmodel(d), over alpha and mu, where
+    d = alpha*precon_grad + mu*prev_update and
+
+      qmodel(d) = (1/2) * d^T * C * d + grad^T*d + L(theta) .
+
+    Unlike in the KL clipping approach, we use the non-approximated quadratic
+ model where the curvature matrix C is the true Fisher on the current
+ mini-batch (computed without any approximations beyond mini-batch sampling),
+ with the usual Tikhonov damping/regularization applied,
+
+ C = F + damping * I
+
+ See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of
+ the formula. See Appendix C for a discussion of the trick of using
+ a factorized Fisher matrix to more efficiently compute the required
+ vector-matrix-vector products.
+
+ Note that the elements of all 4 lists passed to this function must
+ be in correspondence with each other.
+
+ Args:
+ precon_grads: List of preconditioned gradients.
+ prev_updates: List of updates computed at the previous iteration.
+ grads: List of gradients.
+ variables: List of variables in the graph that the update will be
+ applied to. (Note that this function doesn't actually apply the
+ update.)
+
+ Returns:
+ (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the
+ quadratic model, and
+ qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0)
+ = qmodel(alpha*precon_grad + mu*prev_update) - L(theta).
+ """
+
+ cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses,
+ variables)
+
+ # compute the matrix-vector products with the transposed Fisher factor
+ fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads)
+ fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates)
+ batch_size = math_ops.cast(
+ self._batch_size, dtype=fft_precon_grads[0].dtype)
+
+ # compute the entries of the 2x2 matrix
+ m_11 = (
+ _inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size +
+ self.damping * _inner_product_list(precon_grads, precon_grads))
+
+ m_21 = (
+ _inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size +
+ self.damping * _inner_product_list(prev_updates, precon_grads))
+
+ m_22 = (
+ _inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size +
+ self.damping * _inner_product_list(prev_updates, prev_updates))
+
+ def non_zero_prevupd_case():
+ r"""Computes optimal (alpha, mu) given non-zero previous update.
+
+ We solve the full 2x2 linear system. See Martens & Grosse (2015),
+ Section 7, definition of $\alpha^*$ and $\mu^*$.
+
+ Returns:
+ (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize
+ the quadratic model, and
+ qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0).
+ """
+ m = ops.convert_to_tensor([[m_11, m_21], [m_21, m_22]])
+
+ c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)],
+ [_inner_product_list(grads, prev_updates)]])
+
+ sol = -1. * _two_by_two_solve(m, c)
+ alpha = sol[0]
+ mu = sol[1]
+ qmodel_change = 0.5 * math_ops.reduce_sum(sol * c)
+
+ return alpha, mu, qmodel_change
+
+ def zero_prevupd_case():
+ r"""Computes optimal (alpha, mu) given all-zero previous update.
+
+ The linear system reduces to 1x1. See Martens & Grosse (2015),
+ Section 6.4, definition of $\alpha^*$.
+
+ Returns:
+ (alpha, 0.0, qmodel_change), where alpha is chosen to optimize the
+ quadratic model, and
+ qmodel_change = qmodel(alpha*precon_grad) - qmodel(0)
+ """
+ m = m_11
+ c = _inner_product_list(grads, precon_grads)
+
+ alpha = -c / m
+ mu = 0.0
+ qmodel_change = 0.5 * alpha * c
+
+ return alpha, mu, qmodel_change
+
+ return control_flow_ops.cond(
+ math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case)
+
+ def _assign_q_model_change(self, q_model_change):
+ """Assigns `q_model_change` to `self._q_model_change` if damping is adapted.
+
+ Note only the chief worker does the assignment.
+
+ Args:
+ q_model_change: Scalar tensor of type `float32`.
+
+ Returns:
+      If `adapt_damping` is `True` then returns an assign op, otherwise returns
+ a no_op().
+ """
+ if self._adapt_damping and self._is_chief:
+ q_model_assign_op = state_ops.assign(self._q_model_change, q_model_change)
+ else:
+ q_model_assign_op = control_flow_ops.no_op()
+ return q_model_assign_op
+
+ def _compute_qmodel_hyperparams_wrapper(self, grads_and_vars,
+ precon_grads_and_vars):
+ """Wrapper function for `self._compute_qmodel_hyperparams`.
+
+    Constructs a list of preconditioned gradients and variables. Also creates
+    an op to assign the computed q model change to `self._q_model_change`.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs.
+ precon_grads_and_vars: List of (preconditioned gradients, variable)
+ pairs.
+
+ Returns:
+ (alpha, mu, q_model_assign_op), where alpha and mu are chosen to optimize
+ the quadratic model, `q_model_assign_op` assigns the computed q model
+ change to `self._q_model_change`.
+ """
+ precon_grads = list(
+ precon_grad for (precon_grad, _) in precon_grads_and_vars)
+ grads = list(grad for (grad, _) in grads_and_vars)
+ variables = list(var for (_, var) in grads_and_vars)
+ prev_updates = self._compute_prev_updates(variables)
+ # Compute optimal velocity update parameters according to quadratic model
+ alpha, mu, q_model_change = self._compute_qmodel_hyperparams(
+ precon_grads, prev_updates, grads, variables)
+
+ return alpha, mu, self._assign_q_model_change(q_model_change)
+
+ def _compute_update_steps(self, grads_and_vars):
+ """Computes the update steps for the variables given the gradients.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs.
+
+ Returns:
+      A list of tuples (assign_op, var) where `assign_op` assigns the update
+ steps to `var`.
+ """
+
+ if self._momentum_type == "regular":
+ # Compute "preconditioned" gradient.
+ precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
+
+ # Apply "KL clipping" if asked for.
+ if self._norm_constraint is not None:
+ precon_grads_and_vars = self._clip_updates(grads_and_vars,
+ precon_grads_and_vars)
+
+ # Update the velocity with this and return it as the step.
+ if self._adapt_damping and self._is_chief:
+ _, _, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
+ grads_and_vars, precon_grads_and_vars)
+ with ops.control_dependencies([q_model_assign_op]):
+ return self._update_velocities(precon_grads_and_vars, self._momentum)
+ else:
+ return self._update_velocities(precon_grads_and_vars, self._momentum)
+ elif self._momentum_type == "adam":
+ # Update velocity.
+ velocities_and_vars = self._update_velocities(grads_and_vars,
+ self._momentum)
+ # Return "preconditioned" velocity vector as the step.
+ return self._fisher_est.multiply_inverse(velocities_and_vars)
+
+ elif self._momentum_type == "qmodel":
+ # Compute "preconditioned" gradient.
+ precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
+
+ # Compute optimal velocity update parameters according to quadratic model
+ alpha, mu, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
+ grads_and_vars, precon_grads_and_vars)
+
+ with ops.control_dependencies([q_model_assign_op]):
+ return self._update_velocities(
+ precon_grads_and_vars, mu, vec_coeff=-alpha)
+
+ def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0):
+ """Updates the velocities of the variables with the given vectors.
+
+ Args:
+ vecs_and_vars: List of (vector, variable) pairs.
+ decay: How much to decay the old velocity by. This is often referred to
+ as the 'momentum constant'.
+ vec_coeff: Coefficient to apply to the vectors before adding them to the
+ velocity.
+
+ Returns:
+ A list of (velocity, var) indicating the new velocity for each var.
+ """
+
+ def _update_velocity(vec, var):
+ velocity = self._zeros_slot(var, "velocity", self._name)
+ with ops.colocate_with(velocity):
+ # NOTE(mattjj): read/modify/write race condition not suitable for async.
+
+ # Compute the new velocity for this variable.
+ new_velocity = decay * velocity + vec_coeff * vec
+
+ # Save the updated velocity.
+ return (array_ops.identity(velocity.assign(new_velocity)), var)
+
+ # Go through variable and update its associated part of the velocity vector.
+ return [_update_velocity(vec, var) for vec, var in vecs_and_vars]
+
+ def _update_damping(self, prev_batch, global_step):
+ """Adapts damping parameter. Check KFAC (Section 6.5) for the details.
+
+ The damping parameter is updated according to the Levenberg-Marquardt rule
+ every `self._damping_adaptation_interval` iterations.
+
+ Args:
+ prev_batch: Tensor or tuple of tensors which can be passed to
+ `self._loss_fn` to evaluate loss.
+ global_step: `Variable` which keeps track of number of times the training
+ variables have been updated.
+ Returns:
+ A `tf.cond` op which updates the damping parameter.
+ """
+ def compute_damping():
+ """"Adapts damping parameter based on "reduction ratio".
+
+ Reduction ratio captures how closely the quadratic approximation to the
+ loss function approximates the actual loss within a trust region. The
+ damping update tries to make the damping as small as possible while
+ maintaining the property that the quadratic model remains a good local
+ approximation to the loss function.
+
+ Returns:
+ An Op to assign newly computed damping value to `self._damping`.
+ """
+ prev_batch_loss = self._loss_fn(prev_batch)
+ with ops.control_dependencies([prev_batch_loss]):
+ rho_assign = self._rho.assign(
+ (prev_batch_loss - self._prev_loss) / self._q_model_change)
+ with ops.control_dependencies([rho_assign]):
+ new_damping = control_flow_ops.case(
+ [(self._rho < 0.25, lambda: self.damping / self._omega),
+ (self._rho > 0.75, lambda: self.damping * self._omega)],
+ lambda: self.damping)
+ with ops.control_dependencies([new_damping]):
+ new_damping_min = math_ops.maximum(new_damping, self._min_damping)
+ return control_flow_ops.group(self._damping.assign(new_damping_min))
+
+ return control_flow_ops.cond(
+ math_ops.equal(
+ math_ops.mod(global_step + 1, self._damping_adaptation_interval),
+ 0), compute_damping, control_flow_ops.no_op)
+
+
+def _inner_product_list(list1, list2):
+ return math_ops.add_n(
+ [math_ops.reduce_sum(elt1 * elt2) for elt1, elt2 in zip(list1, list2)])
+
+
+def _two_by_two_solve(m, c):
+ # it might be better just to crank out the exact formula for 2x2 inverses
+ return math_ops.matmul(linalg_ops.matrix_inverse(m), c)
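To show how the pieces above fit together, here is a minimal training-loop sketch (illustrative only, not part of this change). It assumes a scalar `loss` tensor and a `layer_collection` whose layers and losses have already been registered; the hyperparameter values are arbitrary placeholders.

    # Illustrative sketch only; `loss` and `layer_collection` are assumed to
    # exist, and the hyperparameters below are arbitrary placeholders.
    import tensorflow as tf
    from tensorflow.contrib.kfac.python.ops import optimizer as kfac_optimizer

    opt = kfac_optimizer.KfacOptimizer(
        learning_rate=0.0001,
        cov_ema_decay=0.95,
        damping=0.001,
        layer_collection=layer_collection,
        momentum=0.9,
        momentum_type="regular")

    # Build covariance/inverse update ops from the thunks, plus a training op.
    cov_update_thunks, inv_update_thunks = opt.make_vars_and_create_op_thunks()
    cov_update_op = tf.group(*(thunk() for thunk in cov_update_thunks))
    inv_update_op = tf.group(*(thunk() for thunk in inv_update_thunks))
    train_op = opt.minimize(loss)

    with tf.train.MonitoredTrainingSession() as sess:
        while not sess.should_stop():
            sess.run([cov_update_op, inv_update_op, train_op])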
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer_lib.py b/tensorflow/contrib/kfac/python/ops/optimizer_lib.py
new file mode 100644
index 0000000000..87d1866e06
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/optimizer_lib.py
@@ -0,0 +1,30 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The KFAC optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.kfac.python.ops.optimizer import *
+from tensorflow.python.util.all_util import remove_undocumented
+# pylint: enable=unused-import,line-too-long,wildcard-import
+
+_allowed_symbols = [
+ "KfacOptimizer",
+]
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py
new file mode 100644
index 0000000000..c4454325ae
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/placement.py
@@ -0,0 +1,114 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implements placement strategies for cov and inv ops, cov variables."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+from tensorflow.python.framework import ops as tf_ops
+
+
+def _make_thunk_on_device(func, device):
+ def thunk():
+ with tf_ops.device(device):
+ return func()
+ return thunk
+
+
+class RoundRobinPlacementMixin(object):
+ """Implements round robin placement strategy for ops and variables."""
+
+ def __init__(self, cov_devices=None, inv_devices=None, **kwargs):
+ """Initializes the RoundRobinPlacementMixin class.
+
+ Args:
+ cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
+ computations will be placed on these devices in a round-robin fashion.
+ Can be None, which means that no devices are specified.
+ inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
+ computations will be placed on these devices in a round-robin fashion.
+ Can be None, which means that no devices are specified.
+      **kwargs: Passed through to the superclass constructor.
+
+ """
+ super(RoundRobinPlacementMixin, self).__init__(**kwargs)
+ self._cov_devices = cov_devices
+ self._inv_devices = inv_devices
+
+ def make_vars_and_create_op_thunks(self, scope=None):
+ """Make vars and create op thunks w/ a round-robin device placement start.
+
+ For each factor, all of that factor's cov variables and their associated
+ update ops will be placed on a particular device. A new device is chosen
+    for each factor by cycling through the list of devices in the
+    `self._cov_devices` attribute. If `self._cov_devices` is `None` then no
+ explicit device placement occurs.
+
+ An analogous strategy is followed for inverse update ops, with the list of
+ devices being given by the `self._inv_devices` attribute.
+
+ Inverse variables on the other hand are not placed on any specific device
+    (they will just use the current device placement context, whatever
+    that happens to be). The idea is that the inverse variables belong where
+ they will be accessed most often, which is the device that actually applies
+ the preconditioner to the gradient. The user will be responsible for setting
+ the device context for this.
+
+ Args:
+ scope: A string or None. If None it will be set to the name of this
+ estimator (given by the name property). All variables will be created,
+ and all thunks will execute, inside of a variable scope of the given
+ name. (Default: None)
+
+ Returns:
+ cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
+ the list of factors given by the "factors" property.
+ inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
+ the list of factors given by the "factors" property.
+ """
+ # Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`.
+ (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw,
+ inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope)
+
+ if self._cov_devices:
+ cov_update_thunks = []
+ for cov_variable_thunk, cov_update_thunk, device in zip(
+ cov_variable_thunks_raw, cov_update_thunks_raw,
+ itertools.cycle(self._cov_devices)):
+ with tf_ops.device(device):
+ cov_variable_thunk()
+ cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk,
+ device))
+ else:
+ for cov_variable_thunk in cov_variable_thunks_raw:
+ cov_variable_thunk()
+ cov_update_thunks = cov_update_thunks_raw
+
+ for inv_variable_thunk in inv_variable_thunks_raw:
+ inv_variable_thunk()
+
+ if self._inv_devices:
+ inv_update_thunks = []
+ for inv_update_thunk, device in zip(inv_update_thunks_raw,
+ itertools.cycle(self._inv_devices)):
+ inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk,
+ device))
+ else:
+ inv_update_thunks = inv_update_thunks_raw
+
+ return cov_update_thunks, inv_update_thunks
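The round-robin assignment above simply zips each factor's thunks with a repeating cycle of the configured devices; the standalone sketch below (illustrative only, not part of this change) shows the resulting pattern.

    # Standalone illustration (not part of this change) of the round-robin
    # pairing produced by zip(thunks, itertools.cycle(devices)).
    import itertools

    thunks = ["factor_%d" % i for i in range(5)]  # stand-ins for update thunks
    devices = ["/gpu:0", "/gpu:1"]

    for thunk, device in zip(thunks, itertools.cycle(devices)):
        print(thunk, "->", device)
    # factor_0 -> /gpu:0, factor_1 -> /gpu:1, factor_2 -> /gpu:0, ...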
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
new file mode 100644
index 0000000000..144295f4c7
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/utils.py
@@ -0,0 +1,709 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import tpu_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+
+# Method used for inverting matrices.
+POSDEF_INV_METHOD = "cholesky"
+POSDEF_EIG_METHOD = "self_adjoint"
+
+
+def set_global_constants(posdef_inv_method=None):
+ """Sets various global constants used by the classes in this module."""
+ global POSDEF_INV_METHOD
+
+ if posdef_inv_method is not None:
+ POSDEF_INV_METHOD = posdef_inv_method
+
+
+class SequenceDict(object):
+ """A dict convenience wrapper that allows getting/setting with sequences."""
+
+ def __init__(self, iterable=None):
+ self._dict = dict(iterable or [])
+
+ def __getitem__(self, key_or_keys):
+ if isinstance(key_or_keys, (tuple, list)):
+ return list(map(self.__getitem__, key_or_keys))
+ else:
+ return self._dict[key_or_keys]
+
+ def __setitem__(self, key_or_keys, val_or_vals):
+ if isinstance(key_or_keys, (tuple, list)):
+ for key, value in zip(key_or_keys, val_or_vals):
+ self[key] = value
+ else:
+ self._dict[key_or_keys] = val_or_vals
+
+ def items(self):
+ return list(self._dict.items())
+
+
+def tensors_to_column(tensors):
+ """Converts a tensor or list of tensors to a column vector.
+
+ Args:
+ tensors: A tensor or list of tensors.
+
+ Returns:
+ The tensors reshaped into vectors and stacked on top of each other.
+ """
+ if isinstance(tensors, (tuple, list)):
+ return array_ops.concat(
+ tuple(array_ops.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0)
+ else:
+ return array_ops.reshape(tensors, [-1, 1])
+
+
+def column_to_tensors(tensors_template, colvec):
+ """Converts a column vector back to the shape of the given template.
+
+ Args:
+ tensors_template: A tensor or list of tensors.
+ colvec: A 2d column vector with the same shape as the value of
+ tensors_to_column(tensors_template).
+
+ Returns:
+ X, where X is tensor or list of tensors with the properties:
+ 1) tensors_to_column(X) = colvec
+ 2) X (or its elements) have the same shape as tensors_template (or its
+ elements)
+ """
+ if isinstance(tensors_template, (tuple, list)):
+ offset = 0
+ tensors = []
+ for tensor_template in tensors_template:
+ sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32)
+ tensor = array_ops.reshape(colvec[offset:(offset + sz)],
+ tensor_template.shape)
+ tensors.append(tensor)
+ offset += sz
+
+ tensors = tuple(tensors)
+ else:
+ tensors = array_ops.reshape(colvec, tensors_template.shape)
+
+ return tensors
+
+
+def kronecker_product(mat1, mat2):
+ """Computes the Kronecker product two matrices."""
+ m1, n1 = mat1.get_shape().as_list()
+ mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1])
+ m2, n2 = mat2.get_shape().as_list()
+ mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
+ return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
+
+
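The reshape-and-broadcast trick used by kronecker_product above can be checked against NumPy's reference implementation; the sketch below is illustrative only and not part of this change.

    # Standalone NumPy check (not part of this change) of the reshape-based
    # Kronecker product: (m1, 1, n1, 1) * (1, m2, 1, n2) -> (m1*m2, n1*n2).
    import numpy as np

    mat1 = np.arange(6.0).reshape(2, 3)
    mat2 = np.arange(8.0).reshape(4, 2)
    m1, n1 = mat1.shape
    m2, n2 = mat2.shape

    kron = (mat1.reshape(m1, 1, n1, 1) * mat2.reshape(1, m2, 1, n2)).reshape(
        m1 * m2, n1 * n2)
    assert np.allclose(kron, np.kron(mat1, mat2))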
+def layer_params_to_mat2d(vector):
+ """Converts a vector shaped like layer parameters to a 2D matrix.
+
+ In particular, we reshape the weights/filter component of the vector to be
+ 2D, flattening all leading (input) dimensions. If there is a bias component,
+ we concatenate it to the reshaped weights/filter component.
+
+ Args:
+ vector: A Tensor or pair of Tensors shaped like layer parameters.
+
+ Returns:
+ A 2D Tensor with the same coefficients and the same output dimension.
+ """
+ if isinstance(vector, (tuple, list)):
+ w_part, b_part = vector
+ w_part_reshaped = array_ops.reshape(w_part,
+ [-1, w_part.shape.as_list()[-1]])
+ return array_ops.concat(
+ (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0)
+ elif isinstance(vector, ops.IndexedSlices):
+ return vector
+ else: # Tensor or Tensor-like.
+ return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]])
+
+
+def mat2d_to_layer_params(vector_template, mat2d):
+ """Converts a canonical 2D matrix representation back to a vector.
+
+ Args:
+ vector_template: A Tensor or pair of Tensors shaped like layer parameters.
+ mat2d: A 2D Tensor with the same shape as the value of
+ layer_params_to_mat2d(vector_template).
+
+ Returns:
+ A Tensor or pair of Tensors with the same coefficients as mat2d and the same
+ shape as vector_template.
+ """
+ if isinstance(vector_template, (tuple, list)):
+ w_part, b_part = mat2d[:-1], mat2d[-1]
+ return array_ops.reshape(w_part, vector_template[0].shape), b_part
+ elif isinstance(vector_template, ops.IndexedSlices):
+ if not isinstance(mat2d, ops.IndexedSlices):
+ raise TypeError(
+ "If vector_template is an IndexedSlices, so should mat2d.")
+ return mat2d
+ else:
+ return array_ops.reshape(mat2d, vector_template.shape)
+
+
+def posdef_inv(tensor, damping):
+ """Computes the inverse of tensor + damping * identity."""
+ identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
+ damping = math_ops.cast(damping, dtype=tensor.dtype)
+ return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping)
+
+
+def posdef_inv_matrix_inverse(tensor, identity, damping):
+ """Computes inverse(tensor + damping * identity) directly."""
+ return linalg_ops.matrix_inverse(tensor + damping * identity)
+
+
+def posdef_inv_cholesky(tensor, identity, damping):
+ """Computes inverse(tensor + damping * identity) with Cholesky."""
+ chol = linalg_ops.cholesky(tensor + damping * identity)
+ return linalg_ops.cholesky_solve(chol, identity)
+
+
+def posdef_inv_eig(tensor, identity, damping):
+ """Computes inverse(tensor + damping * identity) with eigendecomposition."""
+ eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(
+ tensor + damping * identity)
+ return math_ops.matmul(
+ eigenvectors / eigenvalues, eigenvectors, transpose_b=True)
+
+
+posdef_inv_functions = {
+ "matrix_inverse": posdef_inv_matrix_inverse,
+ "cholesky": posdef_inv_cholesky,
+ "eig": posdef_inv_eig,
+}
+
+
+def posdef_eig(mat):
+ """Computes the eigendecomposition of a positive semidefinite matrix."""
+ return posdef_eig_functions[POSDEF_EIG_METHOD](mat)
+
+
+def posdef_eig_svd(mat):
+ """Computes the singular values and left singular vectors of a matrix."""
+ evals, evecs, _ = linalg_ops.svd(mat)
+
+ return evals, evecs
+
+
+def posdef_eig_self_adjoint(mat):
+ """Computes eigendecomposition using self_adjoint_eig."""
+ evals, evecs = linalg_ops.self_adjoint_eig(mat)
+ evals = math_ops.abs(evals) # Should be equivalent to svd approach.
+
+ return evals, evecs
+
+
+posdef_eig_functions = {
+ "self_adjoint": posdef_eig_self_adjoint,
+ "svd": posdef_eig_svd,
+}
+
+
+def cholesky(tensor, damping):
+ """Computes the inverse of tensor + damping * identity."""
+ identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
+ damping = math_ops.cast(damping, dtype=tensor.dtype)
+ return linalg_ops.cholesky(tensor + damping * identity)
+
+
+class SubGraph(object):
+ """Defines a subgraph given by all the dependencies of a given set of outputs.
+ """
+
+ def __init__(self, outputs):
+    # Set of all ancestor Tensors and Ops of 'outputs'.
+ self._members = set()
+
+ self._iter_add(outputs)
+
+ def _iter_add(self, root):
+ """Iteratively adds all of nodes' ancestors using depth first search."""
+ stack = [root]
+ while stack:
+ nodes = stack.pop()
+ for node in nodes:
+ if node in self._members:
+ continue
+ self._members.add(node)
+
+ if isinstance(node, ops.Tensor):
+ stack.append((node.op,))
+ elif isinstance(node, ops.Operation):
+ stack.append(node.inputs)
+
+ def is_member(self, node):
+ """Check if 'node' is in this subgraph."""
+ return node in self._members
+
+ def variable_uses(self, var):
+ """Computes number of times a variable is used.
+
+ Args:
+ var: Variable or ResourceVariable instance.
+
+ Returns:
+ Number of times a variable is used within this subgraph.
+
+ Raises:
+ ValueError: If 'var' is not a variable type.
+ """
+ if isinstance(var, resource_variable_ops.ResourceVariable):
+ var = var.handle
+ elif isinstance(var, variables.Variable):
+ var = var.value()
+ else:
+ raise ValueError("%s does not appear to be a variable." % str(var))
+
+ return len(self._members.intersection(set(var.consumers())))
+
+ def filter_list(self, node_list):
+ """Filters 'node_list' to nodes in this subgraph."""
+ filtered_list = []
+ for node in node_list:
+ if self.is_member(node):
+ filtered_list.append(node)
+ return filtered_list
+
+
+def generate_random_signs(shape, dtype=dtypes.float32):
+ """Generate a random tensor with {-1, +1} entries."""
+ ints = random_ops.random_uniform(shape, maxval=2, dtype=dtypes.int32)
+ return 2 * math_ops.cast(ints, dtype=dtype) - 1
+
+
+def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
+ """Compute forward-mode gradients."""
+ # See b/37888268.
+
+ # This version of forward-mode autodiff is based on code by Tim Cooijmans
+ # and handles list arguments and certain special cases such as when the
+ # ys doesn't depend on one or more of the xs, and when ops.IndexedSlices are
+ # generated by the first gradients_impl.gradients call.
+
+ us = [array_ops.zeros_like(y) + float("nan") for y in ys]
+ dydxs = gradients_impl.gradients(
+ ys, xs, grad_ys=us, stop_gradients=stop_gradients)
+
+  # Deal with strange types that gradients_impl.gradients returns but can't
+  # accept as inputs.
+ dydxs = [
+ ops.convert_to_tensor(dydx)
+ if isinstance(dydx, ops.IndexedSlices) else dydx for dydx in dydxs
+ ]
+ dydxs = [
+ array_ops.zeros_like(x) if dydx is None else dydx
+ for x, dydx in zip(xs, dydxs)
+ ]
+
+ dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs)
+
+ return dysdx
+
+
+def on_tpu():
+ """Returns True when building a TPU computation."""
+ return tpu_function.get_tpu_context().number_of_shards is not None
+
+
+def cross_replica_mean(tensor, name=None):
+ """Takes mean value of a Tensor across all TPU cores.
+
+ Args:
+ tensor: Tensor to be synchronized.
+ name: None or string. Name of Op.
+
+ Returns:
+ Average of Tensor across all TPU cores.
+
+ Raises:
+ ValueError: If called outside of TPU context.
+ """
+ with ops.name_scope(name, "cross_replica_mean", [tensor]):
+ num_shards = tpu_function.get_tpu_context().number_of_shards
+ if num_shards is None:
+ raise ValueError(
+ "Cannot take cross_replica_mean() outside of TPU Context.")
+ if num_shards == 1:
+ return tensor
+ return tpu_ops.cross_replica_sum(tensor / num_shards)
+
+
+def ensure_sequence(obj):
+ """If `obj` isn't a tuple or list, return a tuple containing `obj`."""
+ if isinstance(obj, (tuple, list)):
+ return obj
+ else:
+ return (obj,)
+
+
+def batch_execute(global_step, thunks, batch_size, name=None):
+ """Executes a subset of ops per global step.
+
+ Given a list of thunks, each of which produces a single stateful op,
+ ensures that exactly 'batch_size' ops are run per global step. Ops are
+ scheduled in a round-robin fashion. For example, with 3 ops
+
+ global_step | op0 | op1 | op2
+ ------------+-----+-----+-----
+ 0 | x | x |
+ ------------+-----+-----+-----
+ 1 | x | | x
+ ------------+-----+-----+-----
+ 2 | | x | x
+ ------------+-----+-----+-----
+ 3 | x | x |
+ ------------+-----+-----+-----
+ 4 | x | | x
+
+ Does not guarantee order of op execution within a single global step.
+
+ Args:
+ global_step: Tensor indicating time. Determines which ops run.
+ thunks: List of thunks. Each thunk encapsulates one op. Return values are
+ ignored.
+ batch_size: int. Number of ops to execute per global_step.
+ name: string or None. Name scope for newly added ops.
+
+ Returns:
+ List of ops. Exactly 'batch_size' ops are guaranteed to have an effect
+ every global step.
+ """
+
+ def true_fn(thunk):
+ """Ensures thunk is executed and returns an Op (not a Tensor)."""
+
+ def result():
+ with ops.control_dependencies([thunk()]):
+ return control_flow_ops.no_op()
+
+ return result
+
+ def false_fn(_):
+ """Executes a no-op."""
+
+ def result():
+ return control_flow_ops.no_op()
+
+ return result
+
+ with ops.name_scope(name, "batch_execute"):
+ true_fns = [true_fn(thunk) for thunk in thunks]
+ false_fns = [false_fn(thunk) for thunk in thunks]
+ num_thunks = len(thunks)
+ conditions = [
+ math_ops.less(
+ math_ops.mod(batch_size - 1 + global_step * batch_size - j,
+ num_thunks), batch_size) for j in range(num_thunks)
+ ]
+ result = [
+ control_flow_ops.cond(condition, true_fn, false_fn)
+ for (condition, true_fn,
+ false_fn) in zip(conditions, true_fns, false_fns)
+ ]
+ return result
+
+
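The selection condition used by batch_execute can be reproduced outside TensorFlow; the pure-Python sketch below (illustrative only, not part of this change) prints the same schedule as the table in the docstring above.

    # Standalone illustration (not part of this change) of batch_execute's
    # schedule: op j runs at global_step t iff
    # (batch_size - 1 + t * batch_size - j) % num_thunks < batch_size.
    num_thunks, batch_size = 3, 2
    for step in range(5):
        runs = [j for j in range(num_thunks)
                if (batch_size - 1 + step * batch_size - j) % num_thunks
                < batch_size]
        print(step, runs)
    # 0 [0, 1] / 1 [0, 2] / 2 [1, 2] / 3 [0, 1] / 4 [0, 2]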
+def extract_convolution_patches(inputs,
+ filter_shape,
+ padding,
+ strides=None,
+ dilation_rate=None,
+ name=None,
+ data_format=None):
+ """Extracts inputs to each output coordinate in tf.nn.convolution.
+
+ This is a generalization of tf.extract_image_patches() to tf.nn.convolution(),
+ where the number of spatial dimensions may be something other than 2.
+
+ Assumes:
+ - The first dimension of inputs is batch_size.
+ - The convolution filter is applied to all input channels.
+
+ Args:
+ inputs: Tensor of shape [batch_size, ..spatial_image_shape..,
+ in_channels]. Inputs to tf.nn.convolution().
+ filter_shape: List of ints. Shape of filter passed to tf.nn.convolution().
+ padding: string. Padding method. One of "VALID", "SAME".
+ strides: None or list of ints. Strides along spatial dimensions.
+ dilation_rate: None or list of ints. Dilation along spatial dimensions.
+ name: None or str. Name of Op.
+ data_format: None or str. Format of data.
+
+ Returns:
+ Tensor of shape [batch_size, ..spatial_image_shape..,
+ ..spatial_filter_shape.., in_channels]
+
+ Raises:
+ ValueError: If data_format does not put channel last.
+ ValueError: If inputs and filter disagree on in_channels.
+ """
+ if not is_data_format_channel_last(data_format):
+ raise ValueError("Channel must be last dimension.")
+ with ops.name_scope(name, "extract_convolution_patches",
+ [inputs, filter_shape, padding, strides, dilation_rate]):
+ batch_size = inputs.shape.as_list()[0]
+ in_channels = inputs.shape.as_list()[-1]
+
+ # filter_shape = spatial_filter_shape + [in_channels, out_channels]
+ spatial_filter_shape = filter_shape[:-2]
+ if in_channels != filter_shape[-2]:
+ raise ValueError("inputs and filter_shape must agree on in_channels.")
+
+ # Map each input feature to a location in the output.
+ out_channels = np.prod(spatial_filter_shape) * in_channels
+ filters = linalg_ops.eye(out_channels)
+ filters = array_ops.reshape(
+ filters,
+ list(spatial_filter_shape) + [in_channels, out_channels])
+
+ result = nn_ops.convolution(
+ inputs,
+ filters,
+ padding=padding,
+ strides=strides,
+ dilation_rate=dilation_rate)
+ spatial_output_shape = result.shape.as_list()[1:-1]
+ result = array_ops.reshape(result,
+ [batch_size or -1] + spatial_output_shape +
+ list(spatial_filter_shape) + [in_channels])
+
+ return result
+
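+# Shape sketch with hypothetical sizes: for `images` of shape [8, 32, 32, 3]
+# and filter_shape [5, 5, 3, 16] with "SAME" padding and unit strides, the
+# returned patches have shape [8, 32, 32, 5, 5, 3]. The filter's out_channels
+# (16) does not appear; only its spatial extent and in_channels matter here.
+#
+#   patches = extract_convolution_patches(
+#       images, filter_shape=[5, 5, 3, 16], padding="SAME")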
+
+def extract_pointwise_conv2d_patches(inputs,
+ filter_shape,
+ name=None,
+ data_format=None):
+ """Extract patches for a 1x1 conv2d.
+
+ Args:
+ inputs: 4-D Tensor of shape [batch_size, height, width, in_channels].
+ filter_shape: List of 4 ints. Shape of filter passed to conv2d(); must be
+ 1x1 spatially, i.e. of the form [1, 1, in_channels, out_channels].
+ name: None or str. Name for Op.
+ data_format: None or str. Format for data. See 'data_format' in
+ tf.nn.conv2d() for details.
+
+ Returns:
+ Tensor of shape [batch_size, ..spatial_input_shape..,
+ ..spatial_filter_shape.., in_channels]
+
+ Raises:
+ ValueError: if inputs is not 4-D.
+ ValueError: if filter_shape is not [1, 1, ?, ?]
+ ValueError: if data_format is not channels-last.
+ """
+ if inputs.shape.ndims != 4:
+ raise ValueError("inputs must have 4 dims.")
+ if len(filter_shape) != 4:
+ raise ValueError("filter_shape must have 4 dims.")
+ if filter_shape[0] != 1 or filter_shape[1] != 1:
+ raise ValueError("filter_shape must have shape 1 along spatial dimensions.")
+ if not is_data_format_channel_last(data_format):
+ raise ValueError("data_format must be channels last.")
+ with ops.name_scope(name, "extract_pointwise_conv2d_patches",
+ [inputs, filter_shape]):
+ ksizes = [1, 1, 1, 1] # Spatial shape is 1x1.
+ strides = [1, 1, 1, 1] # Operate on all pixels.
+ rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1.
+ padding = "VALID" # Doesn't matter.
+ result = array_ops.extract_image_patches(inputs, ksizes, strides, rates,
+ padding)
+
+ batch_size, input_height, input_width, in_channels = inputs.shape.as_list()
+ filter_height, filter_width, in_channels, _ = filter_shape
+ return array_ops.reshape(result, [
+ batch_size, input_height, input_width, filter_height, filter_width,
+ in_channels
+ ])
+
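+# Sketch: with a 1x1 filter each output location sees exactly one input
+# pixel, so for inputs of shape [8, 32, 32, 3] and filter_shape
+# [1, 1, 3, 16] the result has shape [8, 32, 32, 1, 1, 3] (hypothetical
+# sizes, for illustration only).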
+
+def is_data_format_channel_last(data_format):
+ """True if data_format puts channel last."""
+ if data_format is None:
+ return True
+ return data_format.endswith("C")
+
+
+def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name
+ """Computes matmul(A, B) where A is sparse, B is dense.
+
+ Args:
+ A: tf.IndexedSlices with dense shape [m, n].
+ B: tf.Tensor with shape [n, k].
+ name: str. Name of op.
+ transpose_a: Bool. If true we transpose A before multiplying it by B.
+ (Default: False)
+ transpose_b: Bool. If true we transpose B before multiplying it by A.
+ (Default: False)
+
+ Returns:
+ tf.IndexedSlices resulting from matmul(A, B).
+
+ Raises:
+ ValueError: If A doesn't represent a matrix.
+ ValueError: If B is not rank-2.
+ """
+ with ops.name_scope(name, "matmul_sparse_dense", [A, B]):
+ if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2:
+ raise ValueError("A must represent a matrix. Found: %s." % A)
+ if B.shape.ndims != 2:
+ raise ValueError("B must be a matrix.")
+ new_values = math_ops.matmul(
+ A.values, B, transpose_a=transpose_a, transpose_b=transpose_b)
+ return ops.IndexedSlices(
+ new_values,
+ A.indices,
+ dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]]))
+
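+# Sketch with hypothetical values: if A is a tf.IndexedSlices holding rows
+# {0, 2} of a [4, 3] matrix (so A.values has shape [2, 3]) and B is a dense
+# [3, 5] Tensor, then matmul_sparse_dense(A, B) returns an IndexedSlices
+# with indices {0, 2}, values of shape [2, 5], and dense_shape [4, 5].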
+
+def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name
+ """Computes matmul(A, B) where A is a diagonal matrix, B is sparse.
+
+ Args:
+ A_diag: Rank-1 Tensor of length m holding the diagonal entries of an
+ implicit matrix A of shape [m, m].
+ B: tf.IndexedSlices. Represents matrix of shape [m, n].
+ name: str. Name of op.
+
+ Returns:
+ tf.IndexedSlices resulting from matmul(A, B).
+
+ Raises:
+ ValueError: If A_diag is not rank-1.
+ ValueError: If B doesn't represent a matrix.
+ """
+ with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]):
+ A_diag = ops.convert_to_tensor(A_diag)
+ if A_diag.shape.ndims != 1:
+ raise ValueError("A_diag must be a rank-1 Tensor.")
+ if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2:
+ raise ValueError("B must represent a matrix. Found: %s." % B)
+ a = array_ops.gather(A_diag, B.indices)
+ a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1))
+ return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape)
+
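+# Sketch with hypothetical values: for A_diag = [2., 3., 4.] and B an
+# IndexedSlices holding rows {0, 2} of a [3, k] matrix, row 0 of B.values is
+# scaled by 2. and row 2 (the second stored row) by 4.; the indices and
+# dense_shape are unchanged.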
+
+class PartitionedTensor(object):
+ """A Tensor partitioned across its 0-th dimension."""
+
+ def __init__(self, tensors):
+ """Initializes PartitionedTensor.
+
+ Args:
+ tensors: List of Tensors. All Tensors must agree on shape (except for the
+ batch dimension) and dtype.
+
+ Raises:
+ ValueError: If 'tensors' has length zero.
+ ValueError: If contents of 'tensors' don't agree on shape or dtype.
+ """
+ if not tensors:
+ raise ValueError("tensors must be a list of 1+ Tensors.")
+
+ dtype = tensors[0].dtype
+ if not all(tensor.dtype == dtype for tensor in tensors):
+ raise ValueError("all tensors must have dtype = %s." % dtype)
+
+ shape = tensors[0].shape[1:]
+ if not all(tensor.shape[1:] == shape for tensor in tensors):
+ raise ValueError("All tensors must have shape = %s (excluding batch "
+ "dimension)." % shape)
+
+ self.tensors = tensors
+ self._concats = {} # {device: Tensor}
+
+ @property
+ def shape(self):
+ feature_shape = self.tensors[0].shape[1:]
+ batch_size = sum([tensor.shape[0] for tensor in self.tensors],
+ tensor_shape.Dimension(0))
+ return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape)
+
+ def get_shape(self):
+ return self.shape
+
+ @property
+ def dtype(self):
+ return self.tensors[0].dtype
+
+ def __str__(self):
+ return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % (
+ self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list()))
+
+ def __hash__(self):
+ return hash(tuple(self.tensors))
+
+ def __eq__(self, other):
+ if not isinstance(other, PartitionedTensor):
+ return False
+ return self.tensors == other.tensors
+
+ def __ne__(self, other):
+ return not self == other # pylint: disable=g-comparison-negation
+
+ def __getitem__(self, key):
+ return self.as_tensor()[key]
+
+ def as_tensor(self, dtype=None, name=None, as_ref=False):
+ with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors):
+ assert not as_ref
+ assert dtype in [None, self.dtype]
+ result = array_ops.concat(self.tensors, axis=0)
+
+ # Cache 'result' if we haven't already cached a value for this device.
+ if result.device not in self._concats:
+ self._concats[result.device] = result
+ return self._concats[result.device]
+
+ @property
+ def device(self):
+ # PartitionedTensors in general do not live on a single device. If the
+ # device cannot be determined unambiguously this property will return None.
+ device = self.tensors[0].device
+ if all(tensor.device == device for tensor in self.tensors):
+ return device
+ return None
+
+
+ops.register_tensor_conversion_function(
+ PartitionedTensor,
+ lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref))
+
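+# Usage sketch (hypothetical tensors): thanks to the conversion function
+# registered above, a PartitionedTensor can be passed directly to ops that
+# expect a Tensor; it is concatenated along the batch dimension on first use.
+#
+#   pt = PartitionedTensor([tower0_acts, tower1_acts])  # [b0, d] and [b1, d]
+#   total = math_ops.reduce_sum(pt)  # pt converts to a [b0 + b1, d] Tensor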
+
+# TODO(b/69623235): Add a function for finding tensors that share gradients
+# to eliminate redundant fisher factor computations.
diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py
new file mode 100644
index 0000000000..330d222dbf
--- /dev/null
+++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py
@@ -0,0 +1,50 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.kfac.python.ops.utils import *
+from tensorflow.python.util.all_util import remove_undocumented
+# pylint: enable=unused-import,line-too-long,wildcard-import
+
+_allowed_symbols = [
+ "set_global_constants",
+ "SequenceDict",
+ "tensors_to_column",
+ "column_to_tensors",
+ "kronecker_product",
+ "layer_params_to_mat2d",
+ "mat2d_to_layer_params",
+ "posdef_inv",
+ "posdef_inv_matrix_inverse",
+ "posdef_inv_cholesky",
+ "posdef_inv_funcs",
+ "SubGraph",
+ "generate_random_signs",
+ "fwd_gradients",
+ "ensure_sequence",
+ "batch_execute",
+ "extract_convolution_patches",
+ "extract_pointwise_conv2d_patches",
+ "is_data_format_channel_last",
+ "matmul_sparse_dense",
+ "matmul_diag_sparse",
+]
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)