aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
authorGravatar Vikram Tankasali <tvikram@google.com>2018-08-22 12:51:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 12:56:27 -0700
commitc73964210ced86791c9231768315fa4652abc9ba (patch)
treed3648366bd10b32e9b15eaa7d0dbcc545f95ae39 /tensorflow/contrib/kfac
parentc85e0a9829258133c84e863b5c35be9e3f9aa280 (diff)
BEGIN_PUBLIC
Delete tf.contrib.kfac. K-FAC in Tensorflow is now its own separate package. END_PUBLIC RELNOTES: n/a Automated rollback of commit 938b9a40787028c58fb548fa6ada8c0dd8180f35 PiperOrigin-RevId: 209813506
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/BUILD26
-rw-r--r--tensorflow/contrib/kfac/README.md93
-rw-r--r--tensorflow/contrib/kfac/__init__.py46
-rw-r--r--tensorflow/contrib/kfac/examples/BUILD80
-rw-r--r--tensorflow/contrib/kfac/examples/convnet.py667
-rw-r--r--tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py62
-rw-r--r--tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py48
-rw-r--r--tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py39
-rw-r--r--tensorflow/contrib/kfac/examples/mlp.py354
-rw-r--r--tensorflow/contrib/kfac/examples/mlp_mnist_main.py64
-rw-r--r--tensorflow/contrib/kfac/examples/mnist.py69
-rw-r--r--tensorflow/contrib/kfac/examples/tests/BUILD52
-rw-r--r--tensorflow/contrib/kfac/examples/tests/convnet_test.py166
-rw-r--r--tensorflow/contrib/kfac/examples/tests/mlp_test.py63
-rw-r--r--tensorflow/contrib/kfac/examples/tests/mnist_test.py72
-rw-r--r--tensorflow/contrib/kfac/g3doc/autoencoder.pngbin54204 -> 0 bytes
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/BUILD160
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py310
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py1018
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py955
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py597
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py190
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py50
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py219
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/utils_test.py410
-rw-r--r--tensorflow/contrib/kfac/python/ops/BUILD263
-rw-r--r--tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py183
-rw-r--r--tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py30
-rw-r--r--tensorflow/contrib/kfac/python/ops/estimator.py516
-rw-r--r--tensorflow/contrib/kfac/python/ops/estimator_lib.py31
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_blocks.py1752
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py45
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors.py1830
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py38
-rw-r--r--tensorflow/contrib/kfac/python/ops/layer_collection.py1269
-rw-r--r--tensorflow/contrib/kfac/python/ops/layer_collection_lib.py46
-rw-r--r--tensorflow/contrib/kfac/python/ops/linear_operator.py95
-rw-r--r--tensorflow/contrib/kfac/python/ops/loss_functions.py754
-rw-r--r--tensorflow/contrib/kfac/python/ops/loss_functions_lib.py39
-rw-r--r--tensorflow/contrib/kfac/python/ops/op_queue.py69
-rw-r--r--tensorflow/contrib/kfac/python/ops/op_queue_lib.py30
-rw-r--r--tensorflow/contrib/kfac/python/ops/optimizer.py727
-rw-r--r--tensorflow/contrib/kfac/python/ops/optimizer_lib.py30
-rw-r--r--tensorflow/contrib/kfac/python/ops/placement.py114
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils.py709
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils_lib.py50
46 files changed, 1 insertions, 14429 deletions
diff --git a/tensorflow/contrib/kfac/BUILD b/tensorflow/contrib/kfac/BUILD
deleted file mode 100644
index b719046b37..0000000000
--- a/tensorflow/contrib/kfac/BUILD
+++ /dev/null
@@ -1,26 +0,0 @@
-# Description:
-# Contains KfacOptimizer, an implementation of the K-FAC optimization
-# algorithm in TensorFlow.
-package(default_visibility = ["//visibility:public"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-py_library(
- name = "kfac",
- srcs = ["__init__.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:curvature_matrix_vector_products_lib",
- "//tensorflow/contrib/kfac/python/ops:fisher_blocks_lib",
- "//tensorflow/contrib/kfac/python/ops:fisher_estimator_lib",
- "//tensorflow/contrib/kfac/python/ops:fisher_factors_lib",
- "//tensorflow/contrib/kfac/python/ops:kfac_optimizer_lib",
- "//tensorflow/contrib/kfac/python/ops:layer_collection_lib",
- "//tensorflow/contrib/kfac/python/ops:loss_functions_lib",
- "//tensorflow/contrib/kfac/python/ops:op_queue_lib",
- "//tensorflow/contrib/kfac/python/ops:utils_lib",
- "//tensorflow/python:util",
- ],
-)
diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md
index 102626925d..42b91d0313 100644
--- a/tensorflow/contrib/kfac/README.md
+++ b/tensorflow/contrib/kfac/README.md
@@ -1,94 +1,3 @@
# K-FAC: Kronecker-Factored Approximate Curvature
-# <font color="red", size=10><u>WARNING: </u></font>
-# ==third_party/tensorflow/contrib/kfac is deprecated. This will be==
-# ==removed on 15-07-2018. <!-- STY:begin_strip_and_replace -->Please import third_party/tensorflow_kfac.==
-# ==<!-- STY:end_strip_and_replace Please check https://github.com/tensorflow/kfac. -->==
-
-**K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an
-approximate second-order optimization method, in TensorFlow. When applied to
-feedforward and convolutional neural networks, K-FAC can converge `>3.5x`
-faster in `>14x` fewer iterations than SGD with Momentum.
-
-[kfac-paper]: https://arxiv.org/abs/1503.05671
-
-## What is K-FAC?
-
-K-FAC, short for "Kronecker-factored Approximate Curvature", is an approximation
-to the [Natural Gradient][natural_gradient] algorithm designed specifically for
-neural networks. It maintains a block-diagonal approximation to the [Fisher
-Information matrix][fisher_information], whose inverse preconditions the
-gradient.
-
-K-FAC can be used in place of SGD, Adam, and other `Optimizer` implementations.
-Experimentally, K-FAC converges `>3.5x` faster than well-tuned SGD.
-
-Unlike most optimizers, K-FAC exploits structure in the model itself (e.g. "What
-are the weights for layer i?"). As such, you must add some additional code while
-constructing your model to use K-FAC.
-
-[natural_gradient]: http://www.mitpressjournals.org/doi/abs/10.1162/089976698300017746
-[fisher_information]: https://en.wikipedia.org/wiki/Fisher_information#Matrix_form
-
-## Why should I use K-FAC?
-
-K-FAC can take advantage of the curvature of the optimization problem, resulting
-in **faster training**. For an 8-layer Autoencoder, K-FAC converges to the same
-loss as SGD with Momentum in 3.8x fewer seconds and 14.7x fewer updates. See how
-training loss changes as a function of number of epochs, steps, and seconds:
-
-![autoencoder](g3doc/autoencoder.png)
-
-## Is K-FAC for me?
-
-If you have a feedforward or convolutional model for classification that is
-converging too slowly, K-FAC is for you. K-FAC can be used in your model if:
-
-* Your model defines a posterior distribution.
-* Your model uses only fully-connected or convolutional layers (residual
- connections OK).
-* You are training on CPU or GPU.
-* You can modify model code to register layers with K-FAC.
-
-## How do I use K-FAC?
-
-Using K-FAC requires three steps:
-
-1. Registering layer inputs, weights, and pre-activations with a
- `LayerCollection`.
-1. Minimizing the loss with a `KfacOptimizer`.
-1. Keeping K-FAC's preconditioner updated.
-
-```python
-# Build model.
-w = tf.get_variable("w", ...)
-b = tf.get_variable("b", ...)
-logits = tf.matmul(x, w) + b
-loss = tf.reduce_mean(
- tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
-
-# Register layers.
-layer_collection = LayerCollection()
-layer_collection.register_fully_connected((w, b), x, logits)
-layer_collection.register_categorical_predictive_distribution(logits)
-
-# Construct training ops.
-optimizer = KfacOptimizer(..., layer_collection=layer_collection)
-train_op = optimizer.minimize(loss)
-
-# Minimize loss.
-with tf.Session() as sess:
- ...
- sess.run([train_op, optimizer.cov_update_op, optimizer.inv_update_op])
-```
-
-See [`examples/`](https://www.tensorflow.org/code/tensorflow/contrib/kfac/examples/) for runnable, end-to-end illustrations.
-
-## Authors
-
-- Alok Aggarwal
-- Daniel Duckworth
-- James Martens
-- Matthew Johnson
-- Olga Wichrowska
-- Roger Grosse
+## KFAC moved to third_party/tensorflow_kfac.
diff --git a/tensorflow/contrib/kfac/__init__.py b/tensorflow/contrib/kfac/__init__.py
deleted file mode 100644
index 1ea354e6cd..0000000000
--- a/tensorflow/contrib/kfac/__init__.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Kronecker-factored Approximate Curvature Optimizer."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long
-from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products_lib as curvature_matrix_vector_products
-from tensorflow.contrib.kfac.python.ops import estimator_lib as estimator
-from tensorflow.contrib.kfac.python.ops import fisher_blocks_lib as fisher_blocks
-from tensorflow.contrib.kfac.python.ops import fisher_factors_lib as fisher_factors
-from tensorflow.contrib.kfac.python.ops import layer_collection_lib as layer_collection
-from tensorflow.contrib.kfac.python.ops import loss_functions_lib as loss_functions
-from tensorflow.contrib.kfac.python.ops import op_queue_lib as op_queue
-from tensorflow.contrib.kfac.python.ops import optimizer_lib as optimizer
-from tensorflow.contrib.kfac.python.ops import utils_lib as utils
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long
-
-_allowed_symbols = [
- "curvature_matrix_vector_products",
- "estimator",
- "fisher_blocks",
- "fisher_factors",
- "layer_collection",
- "loss_functions",
- "op_queue",
- "optimizer",
- "utils",
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD
deleted file mode 100644
index 8186fa1c62..0000000000
--- a/tensorflow/contrib/kfac/examples/BUILD
+++ /dev/null
@@ -1,80 +0,0 @@
-package(default_visibility = [
- "//learning/brain/contrib/kfac/examples:__subpackages__",
- "//tensorflow/contrib/kfac/examples:__subpackages__",
-])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-py_binary(
- name = "mlp_mnist_main",
- srcs = ["mlp_mnist_main.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":mlp",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_library(
- name = "mlp",
- srcs = ["mlp.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":mnist",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_binary(
- name = "convnet_mnist_single_main",
- srcs = ["convnet_mnist_single_main.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":convnet",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_binary(
- name = "convnet_mnist_multi_tower_main",
- srcs = ["convnet_mnist_multi_tower_main.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":convnet",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_binary(
- name = "convnet_mnist_distributed_main",
- srcs = ["convnet_mnist_distributed_main.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":convnet",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_library(
- name = "convnet",
- srcs = ["convnet.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":mlp",
- ":mnist",
- "//tensorflow:tensorflow_py",
- "//third_party/py/numpy",
- ],
-)
-
-py_library(
- name = "mnist",
- srcs = ["mnist.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow:tensorflow_py",
- "//third_party/py/numpy",
- ],
-)
diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py
deleted file mode 100644
index 44e01e1aeb..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet.py
+++ /dev/null
@@ -1,667 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-This library fits a 5-layer ConvNet on MNIST using K-FAC. The model has the
-following structure,
-
-- Conv Layer: 5x5 kernel, 16 output channels.
-- Max Pool: 3x3 kernel, stride 2.
-- Conv Layer: 5x5 kernel, 16 output channels.
-- Max Pool: 3x3 kernel, stride 2.
-- Linear: 10 output dims.
-
-After 3k~6k steps, this should reach perfect accuracy on the training set.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mlp
-from tensorflow.contrib.kfac.examples import mnist
-from tensorflow.contrib.kfac.python.ops import optimizer as opt
-
-
-lc = tf.contrib.kfac.layer_collection
-oq = tf.contrib.kfac.op_queue
-opt = tf.contrib.kfac.optimizer
-
-__all__ = [
- "conv_layer",
- "max_pool_layer",
- "linear_layer",
- "build_model",
- "minimize_loss_single_machine",
- "distributed_grads_only_and_ops_chief_worker",
- "distributed_grads_and_ops_dedicated_workers",
- "train_mnist_single_machine",
- "train_mnist_distributed_sync_replicas",
- "train_mnist_multitower"
-]
-
-
-# Inverse update ops will be run every _INVERT_EVRY iterations.
-_INVERT_EVERY = 10
-
-
-def conv_layer(layer_id, inputs, kernel_size, out_channels):
- """Builds a convolutional layer with ReLU non-linearity.
-
- Args:
- layer_id: int. Integer ID for this layer's variables.
- inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
- corresponds to a single example.
- kernel_size: int. Width and height of the convolution kernel. The kernel is
- assumed to be square.
- out_channels: int. Number of output features per pixel.
-
- Returns:
- preactivations: Tensor of shape [num_examples, width, height, out_channels].
- Values of the layer immediately before the activation function.
- activations: Tensor of shape [num_examples, width, height, out_channels].
- Values of the layer immediately after the activation function.
- params: Tuple of (kernel, bias), parameters for this layer.
- """
- # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
- layer = tf.layers.Conv2D(
- out_channels,
- kernel_size=[kernel_size, kernel_size],
- kernel_initializer=tf.random_normal_initializer(stddev=0.01),
- padding="SAME",
- name="conv_%d" % layer_id)
- preactivations = layer(inputs)
- activations = tf.nn.relu(preactivations)
-
- # layer.weights is a list. This converts it a (hashable) tuple.
- return preactivations, activations, (layer.kernel, layer.bias)
-
-
-def max_pool_layer(layer_id, inputs, kernel_size, stride):
- """Build a max-pooling layer.
-
- Args:
- layer_id: int. Integer ID for this layer's variables.
- inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
- corresponds to a single example.
- kernel_size: int. Width and height to pool over per input channel. The
- kernel is assumed to be square.
- stride: int. Step size between pooling operations.
-
- Returns:
- Tensor of shape [num_examples, width/stride, height/stride, out_channels].
- Result of applying max pooling to 'inputs'.
- """
- # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
- with tf.variable_scope("pool_%d" % layer_id):
- return tf.nn.max_pool(
- inputs, [1, kernel_size, kernel_size, 1], [1, stride, stride, 1],
- padding="SAME",
- name="pool")
-
-
-def linear_layer(layer_id, inputs, output_size):
- """Builds the final linear layer for an MNIST classification problem.
-
- Args:
- layer_id: int. Integer ID for this layer's variables.
- inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
- corresponds to a single example.
- output_size: int. Number of output dims per example.
-
- Returns:
- activations: Tensor of shape [num_examples, output_size]. Values of the
- layer immediately after the activation function.
- params: Tuple of (weights, bias), parameters for this layer.
- """
- # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
- pre, _, params = mlp.fc_layer(layer_id, inputs, output_size)
- return pre, params
-
-
-def build_model(examples, labels, num_labels, layer_collection):
- """Builds a ConvNet classification model.
-
- Args:
- examples: Tensor of shape [num_examples, num_features]. Represents inputs of
- model.
- labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted
- by softmax for each example.
- num_labels: int. Number of distinct values 'labels' can take on.
- layer_collection: LayerCollection instance. Layers will be registered here.
-
- Returns:
- loss: 0-D Tensor representing loss to be minimized.
- accuracy: 0-D Tensor representing model's accuracy.
- """
- # Build a ConvNet. For each layer with parameters, we'll keep track of the
- # preactivations, activations, weights, and bias.
- tf.logging.info("Building model.")
- pre0, act0, params0 = conv_layer(
- layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
- act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
- pre2, act2, params2 = conv_layer(
- layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
- act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
- flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
- logits, params4 = linear_layer(
- layer_id=4, inputs=flat_act3, output_size=num_labels)
- loss = tf.reduce_mean(
- tf.nn.sparse_softmax_cross_entropy_with_logits(
- labels=labels, logits=logits))
- accuracy = tf.reduce_mean(
- tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
-
- with tf.device("/cpu:0"):
- tf.summary.scalar("loss", loss)
- tf.summary.scalar("accuracy", accuracy)
-
- # Register parameters. K-FAC needs to know about the inputs, outputs, and
- # parameters of each conv/fully connected layer and the logits powering the
- # posterior probability over classes.
- tf.logging.info("Building LayerCollection.")
- layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
- pre0)
- layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
- layer_collection.register_fully_connected(params4, flat_act3, logits)
- layer_collection.register_categorical_predictive_distribution(
- logits, name="logits")
-
- return loss, accuracy
-
-
-def minimize_loss_single_machine(loss,
- accuracy,
- layer_collection,
- device="/gpu:0",
- session_config=None):
- """Minimize loss with K-FAC on a single machine.
-
- A single Session is responsible for running all of K-FAC's ops. The covariance
- and inverse update ops are placed on `device`. All model variables are on CPU.
-
- Args:
- loss: 0-D Tensor. Loss to be minimized.
- accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
- layer_collection: LayerCollection instance describing model architecture.
- Used by K-FAC to construct preconditioner.
- device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse
- update ops are run on this device.
- session_config: None or tf.ConfigProto. Configuration for tf.Session().
-
- Returns:
- final value for 'accuracy'.
- """
- # Train with K-FAC.
- g_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=0.0001,
- cov_ema_decay=0.95,
- damping=0.001,
- layer_collection=layer_collection,
- placement_strategy="round_robin",
- cov_devices=[device],
- inv_devices=[device],
- momentum=0.9)
- (cov_update_thunks,
- inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
- def make_update_op(update_thunks):
- update_ops = [thunk() for thunk in update_thunks]
- return tf.group(*update_ops)
-
- cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([cov_update_op]):
- inverse_op = tf.cond(
- tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
- lambda: make_update_op(inv_update_thunks), tf.no_op)
- with tf.control_dependencies([inverse_op]):
- with tf.device(device):
- train_op = optimizer.minimize(loss, global_step=g_step)
-
- tf.logging.info("Starting training.")
- with tf.train.MonitoredTrainingSession(config=session_config) as sess:
- while not sess.should_stop():
- global_step_, loss_, accuracy_, _ = sess.run(
- [g_step, loss, accuracy, train_op])
-
- if global_step_ % _INVERT_EVERY == 0:
- tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
- global_step_, loss_, accuracy_)
-
- return accuracy_
-
-
-def _is_gradient_task(task_id, num_tasks):
- """Returns True if this task should update the weights."""
- if num_tasks < 3:
- return True
- return 0 <= task_id < 0.6 * num_tasks
-
-
-def _is_cov_update_task(task_id, num_tasks):
- """Returns True if this task should update K-FAC's covariance matrices."""
- if num_tasks < 3:
- return False
- return 0.6 * num_tasks <= task_id < num_tasks - 1
-
-
-def _is_inv_update_task(task_id, num_tasks):
- """Returns True if this task should update K-FAC's preconditioner."""
- if num_tasks < 3:
- return False
- return task_id == num_tasks - 1
-
-
-def _num_gradient_tasks(num_tasks):
- """Number of tasks that will update weights."""
- if num_tasks < 3:
- return num_tasks
- return int(np.ceil(0.6 * num_tasks))
-
-
-def _make_distributed_train_op(
- task_id,
- num_worker_tasks,
- num_ps_tasks,
- layer_collection
-):
- """Creates optimizer and distributed training op.
-
- Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes
- the train op.
-
- Args:
- task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
- num_worker_tasks: int. Number of workers in this distributed training setup.
- num_ps_tasks: int. Number of parameter servers holding variables. If 0,
- parameter servers are not used.
- layer_collection: LayerCollection instance describing model architecture.
- Used by K-FAC to construct preconditioner.
-
- Returns:
- sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC
- optimizer.
- optimizer: Instance of `opt.KfacOptimizer`.
- global_step: `tensor`, Global step.
- """
- tf.logging.info("Task id : %d", task_id)
- with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
- global_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=0.0001,
- cov_ema_decay=0.95,
- damping=0.001,
- layer_collection=layer_collection,
- momentum=0.9)
- sync_optimizer = tf.train.SyncReplicasOptimizer(
- opt=optimizer,
- replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks),
- total_num_replicas=num_worker_tasks)
- return sync_optimizer, optimizer, global_step
-
-
-def distributed_grads_only_and_ops_chief_worker(
- task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
- loss, accuracy, layer_collection, invert_every=10):
- """Minimize loss with a synchronous implementation of K-FAC.
-
- All workers perform gradient computation. Chief worker applies gradient after
- averaging the gradients obtained from all the workers. All workers block
- execution until the update is applied. Chief worker runs covariance and
- inverse update ops. Covariance and inverse matrices are placed on parameter
- servers in a round robin manner. For further details on synchronous
- distributed optimization check `tf.train.SyncReplicasOptimizer`.
-
- Args:
- task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
- is_chief: `boolean`, `True` if the worker is chief worker.
- num_worker_tasks: int. Number of workers in this distributed training setup.
- num_ps_tasks: int. Number of parameter servers holding variables. If 0,
- parameter servers are not used.
- master: string. IP and port of TensorFlow runtime process. Set to empty
- string to run locally.
- checkpoint_dir: string or None. Path to store checkpoints under.
- loss: 0-D Tensor. Loss to be minimized.
- accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to
- run with each step.
- layer_collection: LayerCollection instance describing model architecture.
- Used by K-FAC to construct preconditioner.
- invert_every: `int`, Number of steps between update the inverse.
-
- Returns:
- final value for 'accuracy'.
-
- Raises:
- ValueError: if task_id >= num_worker_tasks.
- """
-
- sync_optimizer, optimizer, global_step = _make_distributed_train_op(
- task_id, num_worker_tasks, num_ps_tasks, layer_collection)
- (cov_update_thunks,
- inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
- tf.logging.info("Starting training.")
- hooks = [sync_optimizer.make_session_run_hook(is_chief)]
-
- def make_update_op(update_thunks):
- update_ops = [thunk() for thunk in update_thunks]
- return tf.group(*update_ops)
-
- if is_chief:
- cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([cov_update_op]):
- inverse_op = tf.cond(
- tf.equal(tf.mod(global_step, invert_every), 0),
- lambda: make_update_op(inv_update_thunks),
- tf.no_op)
- with tf.control_dependencies([inverse_op]):
- train_op = sync_optimizer.minimize(loss, global_step=global_step)
- else:
- train_op = sync_optimizer.minimize(loss, global_step=global_step)
-
- with tf.train.MonitoredTrainingSession(
- master=master,
- is_chief=is_chief,
- checkpoint_dir=checkpoint_dir,
- hooks=hooks,
- stop_grace_period_secs=0) as sess:
- while not sess.should_stop():
- global_step_, loss_, accuracy_, _ = sess.run(
- [global_step, loss, accuracy, train_op])
- tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
- loss_, accuracy_)
- return accuracy_
-
-
-def distributed_grads_and_ops_dedicated_workers(
- task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
- loss, accuracy, layer_collection):
- """Minimize loss with a synchronous implementation of K-FAC.
-
- Different workers are responsible for different parts of K-FAC's Ops. The
- first 60% of tasks compute gradients; the next 20% accumulate covariance
- statistics; the last 20% invert the matrices used to precondition gradients.
- The chief worker applies the gradient .
-
- Args:
- task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
- is_chief: `boolean`, `True` if the worker is chief worker.
- num_worker_tasks: int. Number of workers in this distributed training setup.
- num_ps_tasks: int. Number of parameter servers holding variables. If 0,
- parameter servers are not used.
- master: string. IP and port of TensorFlow runtime process. Set to empty
- string to run locally.
- checkpoint_dir: string or None. Path to store checkpoints under.
- loss: 0-D Tensor. Loss to be minimized.
- accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to
- run with each step.
- layer_collection: LayerCollection instance describing model architecture.
- Used by K-FAC to construct preconditioner.
-
- Returns:
- final value for 'accuracy'.
-
- Raises:
- ValueError: if task_id >= num_worker_tasks.
- """
- sync_optimizer, optimizer, global_step = _make_distributed_train_op(
- task_id, num_worker_tasks, num_ps_tasks, layer_collection)
- _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars()
- train_op = sync_optimizer.minimize(loss, global_step=global_step)
- inv_update_queue = oq.OpQueue(inv_update_ops)
-
- tf.logging.info("Starting training.")
- is_chief = (task_id == 0)
- hooks = [sync_optimizer.make_session_run_hook(is_chief)]
- with tf.train.MonitoredTrainingSession(
- master=master,
- is_chief=is_chief,
- checkpoint_dir=checkpoint_dir,
- hooks=hooks,
- stop_grace_period_secs=0) as sess:
- while not sess.should_stop():
- # Choose which op this task is responsible for running.
- if _is_gradient_task(task_id, num_worker_tasks):
- learning_op = train_op
- elif _is_cov_update_task(task_id, num_worker_tasks):
- learning_op = cov_update_op
- elif _is_inv_update_task(task_id, num_worker_tasks):
- # TODO(duckworthd): Running this op before cov_update_op has been run a
- # few times can result in "InvalidArgumentError: Cholesky decomposition
- # was not successful." Delay running this op until cov_update_op has
- # been run a few times.
- learning_op = inv_update_queue.next_op(sess)
- else:
- raise ValueError("Which op should task %d do?" % task_id)
-
- global_step_, loss_, accuracy_, _ = sess.run(
- [global_step, loss, accuracy, learning_op])
- tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
- loss_, accuracy_)
-
- return accuracy_
-
-
-def train_mnist_single_machine(data_dir,
- num_epochs,
- use_fake_data=False,
- device="/gpu:0"):
- """Train a ConvNet on MNIST.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- use_fake_data: bool. If True, generate a synthetic dataset.
- device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse
- update ops are run on this device.
-
- Returns:
- accuracy of model on the final minibatch of training data.
- """
- # Load a dataset.
- tf.logging.info("Loading MNIST into memory.")
- examples, labels = mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=128,
- use_fake_data=use_fake_data,
- flatten_images=False)
-
- # Build a ConvNet.
- layer_collection = lc.LayerCollection()
- loss, accuracy = build_model(
- examples, labels, num_labels=10, layer_collection=layer_collection)
-
- # Fit model.
- return minimize_loss_single_machine(
- loss, accuracy, layer_collection, device=device)
-
-
-def train_mnist_multitower(data_dir, num_epochs, num_towers,
- use_fake_data=True, devices=None):
- """Train a ConvNet on MNIST.
-
- Training data is split equally among the towers. Each tower computes loss on
- its own batch of data and the loss is aggregated on the CPU. The model
- variables are placed on first tower. The covariance and inverse update ops
- and variables are placed on GPUs in a round robin manner.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- num_towers: int. Number of CPUs to split inference across.
- use_fake_data: bool. If True, generate a synthetic dataset.
- devices: string, Either list of CPU or GPU. The covariance and inverse
- update ops are run on this device.
-
- Returns:
- accuracy of model on the final minibatch of training data.
- """
- if devices:
- device_count = {"GPU": num_towers}
- else:
- device_count = {"CPU": num_towers}
-
- devices = devices or [
- "/cpu:{}".format(tower_id) for tower_id in range(num_towers)
- ]
- # Load a dataset.
- tf.logging.info("Loading MNIST into memory.")
- tower_batch_size = 128
- batch_size = tower_batch_size * num_towers
- tf.logging.info(
- ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
- "tower batch size.") % (batch_size, num_towers, tower_batch_size))
- examples, labels = mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=batch_size,
- use_fake_data=use_fake_data,
- flatten_images=False)
-
- # Split minibatch across towers.
- examples = tf.split(examples, num_towers)
- labels = tf.split(labels, num_towers)
-
- # Build an MLP. Each tower's layers will be added to the LayerCollection.
- layer_collection = lc.LayerCollection()
- tower_results = []
- for tower_id in range(num_towers):
- with tf.device(devices[tower_id]):
- with tf.name_scope("tower%d" % tower_id):
- with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
- tf.logging.info("Building tower %d." % tower_id)
- tower_results.append(
- build_model(examples[tower_id], labels[tower_id], 10,
- layer_collection))
- losses, accuracies = zip(*tower_results)
-
- # Average across towers.
- loss = tf.reduce_mean(losses)
- accuracy = tf.reduce_mean(accuracies)
-
- # Fit model.
-
- session_config = tf.ConfigProto(
- allow_soft_placement=False,
- device_count=device_count,
- )
-
- g_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=0.0001,
- cov_ema_decay=0.95,
- damping=0.001,
- layer_collection=layer_collection,
- placement_strategy="round_robin",
- cov_devices=devices,
- inv_devices=devices,
- momentum=0.9)
- (cov_update_thunks,
- inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
- def make_update_op(update_thunks):
- update_ops = [thunk() for thunk in update_thunks]
- return tf.group(*update_ops)
-
- cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([cov_update_op]):
- inverse_op = tf.cond(
- tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
- lambda: make_update_op(inv_update_thunks), tf.no_op)
- with tf.control_dependencies([inverse_op]):
- train_op = optimizer.minimize(loss, global_step=g_step)
-
- tf.logging.info("Starting training.")
- with tf.train.MonitoredTrainingSession(config=session_config) as sess:
- while not sess.should_stop():
- global_step_, loss_, accuracy_, _ = sess.run(
- [g_step, loss, accuracy, train_op])
-
- if global_step_ % _INVERT_EVERY == 0:
- tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
- global_step_, loss_, accuracy_)
-
-
-def train_mnist_distributed_sync_replicas(task_id,
- is_chief,
- num_worker_tasks,
- num_ps_tasks,
- master,
- data_dir,
- num_epochs,
- op_strategy,
- use_fake_data=False):
- """Train a ConvNet on MNIST using Sync replicas optimizer.
-
- Args:
- task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
- is_chief: `boolean`, `True` if the worker is chief worker.
- num_worker_tasks: int. Number of workers in this distributed training setup.
- num_ps_tasks: int. Number of parameter servers holding variables.
- master: string. IP and port of TensorFlow runtime process.
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- op_strategy: `string`, Strategy to run the covariance and inverse
- ops. If op_strategy == `chief_worker` then covariance and inverse
- update ops are run on chief worker otherwise they are run on dedicated
- workers.
-
- use_fake_data: bool. If True, generate a synthetic dataset.
-
- Returns:
- accuracy of model on the final minibatch of training data.
-
- Raises:
- ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"].
- """
- # Load a dataset.
- tf.logging.info("Loading MNIST into memory.")
- examples, labels = mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=128,
- use_fake_data=use_fake_data,
- flatten_images=False)
-
- # Build a ConvNet.
- layer_collection = lc.LayerCollection()
- with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
- loss, accuracy = build_model(
- examples, labels, num_labels=10, layer_collection=layer_collection)
-
- # Fit model.
- checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac")
- if op_strategy == "chief_worker":
- return distributed_grads_only_and_ops_chief_worker(
- task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
- checkpoint_dir, loss, accuracy, layer_collection)
- elif op_strategy == "dedicated_workers":
- return distributed_grads_and_ops_dedicated_workers(
- task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
- checkpoint_dir, loss, accuracy, layer_collection)
- else:
- raise ValueError("Only supported op strategies are : {}, {}".format(
- "chief_worker", "dedicated_workers"))
-
-
-if __name__ == "__main__":
- tf.app.run()
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py
deleted file mode 100644
index b4c2d4a9e9..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-Distributed training with sync replicas optimizer. See
-`convnet.train_mnist_distributed_sync_replicas` for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-from absl import flags
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import convnet
-
-FLAGS = flags.FLAGS
-flags.DEFINE_integer("task", -1, "Task identifier")
-flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
-flags.DEFINE_string(
- "cov_inv_op_strategy", "chief_worker",
- "In dist training mode run the cov, inv ops on chief or dedicated workers."
-)
-flags.DEFINE_string("master", "local", "Session master.")
-flags.DEFINE_integer("ps_tasks", 2,
- "Number of tasks in the parameter server job.")
-flags.DEFINE_integer("replicas_to_aggregate", 5,
- "Number of replicas to aggregate.")
-flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.")
-flags.DEFINE_integer("num_epochs", None, "Number of epochs.")
-
-
-def _is_chief():
- """Determines whether a job is the chief worker."""
- if "chief_worker" in FLAGS.brain_jobs:
- return FLAGS.brain_job_name == "chief_worker"
- else:
- return FLAGS.task == 0
-
-
-def main(unused_argv):
- _ = unused_argv
- convnet.train_mnist_distributed_sync_replicas(
- FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks,
- FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy)
-
-if __name__ == "__main__":
- tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py
deleted file mode 100644
index 4249bf8a8d..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-Multi tower training mode. See `convnet.train_mnist_multitower` for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-from absl import flags
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import convnet
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir")
-flags.DEFINE_integer("num_towers", 2,
- "Number of towers for multi tower training.")
-
-
-def main(unused_argv):
- _ = unused_argv
- assert FLAGS.num_towers > 1
- devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)]
- convnet.train_mnist_multitower(
- FLAGS.data_dir,
- num_epochs=200,
- num_towers=FLAGS.num_towers,
- devices=devices)
-
-
-if __name__ == "__main__":
- tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
deleted file mode 100644
index 2c1f099360..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-Train on single machine. See `convnet.train_mnist_single_machine` for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-from absl import flags
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import convnet
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
-
-
-def main(unused_argv):
- convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
-
-
-if __name__ == "__main__":
- tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py
deleted file mode 100644
index ea2b252a05..0000000000
--- a/tensorflow/contrib/kfac/examples/mlp.py
+++ /dev/null
@@ -1,354 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train an MLP on MNIST using K-FAC.
-
-This library fits a 3-layer, tanh-activated MLP on MNIST using K-FAC. After
-~25k steps, this should reach perfect accuracy on the training set.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mnist
-
-lc = tf.contrib.kfac.layer_collection
-opt = tf.contrib.kfac.optimizer
-
-__all__ = [
- "fc_layer",
- "train_mnist",
- "train_mnist_multitower",
-]
-
-
-def fc_layer(layer_id, inputs, output_size):
- """Builds a fully connected layer.
-
- Args:
- layer_id: int. Integer ID for this layer's variables.
- inputs: Tensor of shape [num_examples, input_size]. Each row corresponds
- to a single example.
- output_size: int. Number of output dimensions after fully connected layer.
-
- Returns:
- preactivations: Tensor of shape [num_examples, output_size]. Values of the
- layer immediately before the activation function.
- activations: Tensor of shape [num_examples, output_size]. Values of the
- layer immediately after the activation function.
- params: Tuple of (weights, bias), parameters for this layer.
- """
- # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
- layer = tf.layers.Dense(
- output_size,
- kernel_initializer=tf.random_normal_initializer(),
- name="fc_%d" % layer_id)
- preactivations = layer(inputs)
- activations = tf.nn.tanh(preactivations)
-
- # layer.weights is a list. This converts it a (hashable) tuple.
- return preactivations, activations, (layer.kernel, layer.bias)
-
-
-def build_model(examples, labels, num_labels, layer_collection):
- """Builds an MLP classification model.
-
- Args:
- examples: Tensor of shape [num_examples, num_features]. Represents inputs of
- model.
- labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted
- by softmax for each example.
- num_labels: int. Number of distinct values 'labels' can take on.
- layer_collection: LayerCollection instance describing model architecture.
-
- Returns:
- loss: 0-D Tensor representing loss to be minimized.
- accuracy: 0-D Tensor representing model's accuracy.
- """
- # Build an MLP. For each layer, we'll keep track of the preactivations,
- # activations, weights, and bias.
- pre0, act0, params0 = fc_layer(layer_id=0, inputs=examples, output_size=128)
- pre1, act1, params1 = fc_layer(layer_id=1, inputs=act0, output_size=64)
- pre2, act2, params2 = fc_layer(layer_id=2, inputs=act1, output_size=32)
- logits, _, params3 = fc_layer(layer_id=3, inputs=act2, output_size=num_labels)
- loss = tf.reduce_mean(
- tf.nn.sparse_softmax_cross_entropy_with_logits(
- labels=labels, logits=logits))
- accuracy = tf.reduce_mean(
- tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
-
- # Register parameters. K-FAC needs to know about the inputs, outputs, and
- # parameters of each layer and the logits powering the posterior probability
- # over classes.
- tf.logging.info("Building LayerCollection.")
- layer_collection.register_fully_connected(params0, examples, pre0)
- layer_collection.register_fully_connected(params1, act0, pre1)
- layer_collection.register_fully_connected(params2, act1, pre2)
- layer_collection.register_fully_connected(params3, act2, logits)
- layer_collection.register_categorical_predictive_distribution(
- logits, name="logits")
-
- return loss, accuracy
-
-
-def minimize(loss, accuracy, layer_collection, num_towers, session_config=None):
- """Minimize 'loss' with KfacOptimizer.
-
- Args:
- loss: 0-D Tensor. Loss to be minimized.
- accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
- layer_collection: LayerCollection instance. Describes layers in model.
- num_towers: int. Number of CPUs to split minibatch across.
- session_config: tf.ConfigProto. Configuration for tf.Session().
-
- Returns:
- accuracy of classifier on final minibatch.
- """
- devices = tuple("/cpu:%d" % tower_id for tower_id in range(num_towers))
-
- # Train with K-FAC. We'll use a decreasing learning rate that's cut in 1/2
- # every 10k iterations.
- tf.logging.info("Building KFAC Optimizer.")
- global_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=tf.train.exponential_decay(
- 0.00002, global_step, 10000, 0.5, staircase=True),
- cov_ema_decay=0.95,
- damping=0.0005,
- layer_collection=layer_collection,
- momentum=0.99,
- placement_strategy="round_robin",
- cov_devices=devices,
- inv_devices=devices)
-
- (cov_update_thunks,
- inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
- def make_update_op(update_thunks):
- update_ops = [thunk() for thunk in update_thunks]
- return tf.group(*update_ops)
-
- # TODO(b/78537047): change (some) examples to use PeriodicInvCovUpdateKfacOpt
- # once that gets moved over? Could still leave more advanced examples as they
- # are (e.g. train_mnist_estimator in this file)
-
- cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([cov_update_op]):
- # We update the inverses only every 20 iterations.
- inverse_op = tf.cond(
- tf.equal(tf.mod(global_step, 100), 0),
- lambda: make_update_op(inv_update_thunks), tf.no_op)
- with tf.control_dependencies([inverse_op]):
- train_op = optimizer.minimize(loss, global_step=global_step)
-
- tf.logging.info("Starting training.")
- with tf.train.MonitoredTrainingSession(config=session_config) as sess:
- while not sess.should_stop():
- global_step_, loss_, accuracy_, _ = sess.run(
- [global_step, loss, accuracy, train_op])
-
- if global_step_ % 100 == 0:
- tf.logging.info("global_step: %d | loss: %f | accuracy: %f",
- global_step_, loss_, accuracy_)
-
- return accuracy_
-
-
-def train_mnist(data_dir, num_epochs, use_fake_data=False):
- """Train an MLP on MNIST.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- use_fake_data: bool. If True, generate a synthetic dataset.
-
- Returns:
- accuracy of model on the final minibatch of training data.
- """
- # Load a dataset.
- tf.logging.info("Loading MNIST into memory.")
- examples, labels = mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=64,
- flatten_images=True,
- use_fake_data=use_fake_data)
-
- # Build an MLP. The model's layers will be added to the LayerCollection.
- tf.logging.info("Building model.")
- layer_collection = lc.LayerCollection()
- loss, accuracy = build_model(examples, labels, 10, layer_collection)
-
- # Fit model.
- minimize(loss, accuracy, layer_collection, 1)
-
-
-def train_mnist_multitower(data_dir,
- num_epochs,
- num_towers,
- use_fake_data=False):
- """Train an MLP on MNIST, splitting the minibatch across multiple towers.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- num_towers: int. Number of CPUs to split minibatch across.
- use_fake_data: bool. If True, generate a synthetic dataset.
-
- Returns:
- accuracy of model on the final minibatch of training data.
- """
- # Load a dataset.
- tower_batch_size = 64
- batch_size = tower_batch_size * num_towers
- tf.logging.info(
- ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
- "tower batch size.") % (batch_size, num_towers, tower_batch_size))
- examples, labels = mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=batch_size,
- flatten_images=True,
- use_fake_data=use_fake_data)
-
- # Split minibatch across towers.
- examples = tf.split(examples, num_towers)
- labels = tf.split(labels, num_towers)
-
- # Build an MLP. Each tower's layers will be added to the LayerCollection.
- layer_collection = lc.LayerCollection()
- tower_results = []
- for tower_id in range(num_towers):
- with tf.device("/cpu:%d" % tower_id):
- with tf.name_scope("tower%d" % tower_id):
- with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
- tf.logging.info("Building tower %d." % tower_id)
- tower_results.append(
- build_model(examples[tower_id], labels[tower_id], 10,
- layer_collection))
- losses, accuracies = zip(*tower_results)
-
- # Average across towers.
- loss = tf.reduce_mean(losses)
- accuracy = tf.reduce_mean(accuracies)
-
- # Fit model.
- session_config = tf.ConfigProto(
- allow_soft_placement=False, device_count={
- "CPU": num_towers
- })
- return minimize(
- loss, accuracy, layer_collection, num_towers,
- session_config=session_config)
-
-
-def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False):
- """Train an MLP on MNIST using tf.estimator.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- use_fake_data: bool. If True, generate a synthetic dataset.
-
- Returns:
- accuracy of model on the final minibatch of training data.
- """
-
- # Load a dataset.
- def input_fn():
- tf.logging.info("Loading MNIST into memory.")
- return mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=64,
- flatten_images=True,
- use_fake_data=use_fake_data)
-
- def model_fn(features, labels, mode, params):
- """Model function for MLP trained with K-FAC.
-
- Args:
- features: Tensor of shape [batch_size, input_size]. Input features.
- labels: Tensor of shape [batch_size]. Target labels for training.
- mode: tf.estimator.ModeKey. Must be TRAIN.
- params: ignored.
-
- Returns:
- EstimatorSpec for training.
-
- Raises:
- ValueError: If 'mode' is anything other than TRAIN.
- """
- del params
-
- if mode != tf.estimator.ModeKeys.TRAIN:
- raise ValueError("Only training is supposed with this API.")
-
- # Build a ConvNet.
- layer_collection = lc.LayerCollection()
- loss, accuracy = build_model(
- features, labels, num_labels=10, layer_collection=layer_collection)
-
- # Train with K-FAC.
- global_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=tf.train.exponential_decay(
- 0.00002, global_step, 10000, 0.5, staircase=True),
- cov_ema_decay=0.95,
- damping=0.0001,
- layer_collection=layer_collection,
- momentum=0.99)
-
- (cov_update_thunks,
- inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
- def make_update_op(update_thunks):
- update_ops = [thunk() for thunk in update_thunks]
- return tf.group(*update_ops)
-
- def make_batch_executed_op(update_thunks, batch_size=1):
- return tf.group(*tf.contrib.kfac.utils.batch_execute(
- global_step, update_thunks, batch_size=batch_size))
-
- # Run cov_update_op every step. Run 1 inv_update_ops per step.
- cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([cov_update_op]):
- # But make sure to execute all the inverse ops on the first step
- inverse_op = tf.cond(tf.equal(global_step, 0),
- lambda: make_update_op(inv_update_thunks),
- lambda: make_batch_executed_op(inv_update_thunks))
- with tf.control_dependencies([inverse_op]):
- train_op = optimizer.minimize(loss, global_step=global_step)
-
- # Print metrics every 5 sec.
- hooks = [
- tf.train.LoggingTensorHook(
- {
- "loss": loss,
- "accuracy": accuracy
- }, every_n_secs=5),
- ]
- return tf.estimator.EstimatorSpec(
- mode=mode, loss=loss, train_op=train_op, training_hooks=hooks)
-
- run_config = tf.estimator.RunConfig(
- model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100)
-
- # Train until input_fn() is empty with Estimator. This is a prerequisite for
- # TPU compatibility.
- estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
- estimator.train(input_fn=input_fn)
diff --git a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py
deleted file mode 100644
index 9c34ade1d2..0000000000
--- a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train an MLP on MNIST using K-FAC.
-
-See mlp.py for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import sys
-
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mlp
-
-FLAGS = None
-
-
-def main(argv):
- _ = argv
- if FLAGS.use_estimator:
- if FLAGS.num_towers != 1:
- raise ValueError("Only 1 device supported in tf.estimator example.")
- mlp.train_mnist_estimator(FLAGS.data_dir, num_epochs=200)
- elif FLAGS.num_towers > 1:
- mlp.train_mnist_multitower(
- FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers)
- else:
- mlp.train_mnist(FLAGS.data_dir, num_epochs=200)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--data_dir",
- type=str,
- default="/tmp/mnist",
- help="Directory to store dataset in.")
- parser.add_argument(
- "--num_towers",
- type=int,
- default=1,
- help="Number of CPUs to split minibatch across.")
- parser.add_argument(
- "--use_estimator",
- action="store_true",
- help="Use tf.estimator API to train.")
- FLAGS, unparsed = parser.parse_known_args()
- tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/kfac/examples/mnist.py b/tensorflow/contrib/kfac/examples/mnist.py
deleted file mode 100644
index 547c4ab25d..0000000000
--- a/tensorflow/contrib/kfac/examples/mnist.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utilities for loading MNIST into TensorFlow."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-__all__ = [
- 'load_mnist',
-]
-
-
-def load_mnist(data_dir,
- num_epochs,
- batch_size,
- flatten_images=True,
- use_fake_data=False):
- """Loads MNIST dataset into memory.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the dataset.
- batch_size: int. Number of examples per minibatch.
- flatten_images: bool. If True, [28, 28, 1]-shaped images are flattened into
- [784]-shaped vectors.
- use_fake_data: bool. If True, generate a synthetic dataset rather than
- reading MNIST in.
-
- Returns:
- examples: Tensor of shape [batch_size, 784] if 'flatten_images' is
- True, else [batch_size, 28, 28, 1]. Each row is one example.
- Values in [0, 1].
- labels: Tensor of shape [batch_size]. Indices of integer corresponding to
- each example. Values in {0...9}.
- """
- if use_fake_data:
- rng = np.random.RandomState(42)
- num_examples = batch_size * 4
- images = rng.rand(num_examples, 28 * 28)
- if not flatten_images:
- images = np.reshape(images, [num_examples, 28, 28, 1])
- labels = rng.randint(10, size=num_examples)
- else:
- mnist_data = tf.contrib.learn.datasets.mnist.read_data_sets(
- data_dir, reshape=flatten_images)
- num_examples = len(mnist_data.train.labels)
- images = mnist_data.train.images
- labels = mnist_data.train.labels
-
- dataset = tf.data.Dataset.from_tensor_slices((np.asarray(
- images, dtype=np.float32), np.asarray(labels, dtype=np.int64)))
- return (dataset.repeat(num_epochs).shuffle(num_examples).batch(batch_size)
- .make_one_shot_iterator().get_next())
diff --git a/tensorflow/contrib/kfac/examples/tests/BUILD b/tensorflow/contrib/kfac/examples/tests/BUILD
deleted file mode 100644
index ede7f183fe..0000000000
--- a/tensorflow/contrib/kfac/examples/tests/BUILD
+++ /dev/null
@@ -1,52 +0,0 @@
-package(default_visibility = ["//visibility:private"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_test(
- name = "mlp_test",
- size = "large",
- srcs = ["mlp_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- "notsan",
- ],
- deps = [
- "//tensorflow:tensorflow_py",
- "//tensorflow/contrib/kfac/examples:mlp",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "convnet_test",
- size = "large",
- srcs = ["convnet_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- "notsan",
- ],
- deps = [
- "//tensorflow:tensorflow_py",
- "//tensorflow/contrib/kfac",
- "//tensorflow/contrib/kfac/examples:convnet",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "mnist_test",
- srcs = ["mnist_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow:tensorflow_py",
- "//tensorflow/contrib/kfac/examples:mnist",
- "//third_party/py/numpy",
- ],
-)
diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py
deleted file mode 100644
index adecda7166..0000000000
--- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py
+++ /dev/null
@@ -1,166 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for convnet.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.kfac import layer_collection as lc
-from tensorflow.contrib.kfac.examples import convnet
-
-
-class ConvNetTest(tf.test.TestCase):
-
- def testConvLayer(self):
- with tf.Graph().as_default():
- pre, act, (w, b) = convnet.conv_layer(
- layer_id=1,
- inputs=tf.zeros([5, 3, 3, 2]),
- kernel_size=3,
- out_channels=5)
- self.assertShapeEqual(np.zeros([5, 3, 3, 5]), pre)
- self.assertShapeEqual(np.zeros([5, 3, 3, 5]), act)
- self.assertShapeEqual(np.zeros([3, 3, 2, 5]), tf.convert_to_tensor(w))
- self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b))
- self.assertIsInstance(w, tf.Variable)
- self.assertIsInstance(b, tf.Variable)
- self.assertIn("conv_1", w.op.name)
- self.assertIn("conv_1", b.op.name)
-
- def testMaxPoolLayer(self):
- with tf.Graph().as_default():
- act = convnet.max_pool_layer(
- layer_id=1, inputs=tf.zeros([5, 6, 6, 2]), kernel_size=5, stride=3)
- self.assertShapeEqual(np.zeros([5, 2, 2, 2]), act)
- self.assertEqual(act.op.name, "pool_1/pool")
-
- def testLinearLayer(self):
- with tf.Graph().as_default():
- act, (w, b) = convnet.linear_layer(
- layer_id=1, inputs=tf.zeros([5, 20]), output_size=5)
- self.assertShapeEqual(np.zeros([5, 5]), act)
- self.assertShapeEqual(np.zeros([20, 5]), tf.convert_to_tensor(w))
- self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b))
- self.assertIsInstance(w, tf.Variable)
- self.assertIsInstance(b, tf.Variable)
- self.assertIn("fc_1", w.op.name)
- self.assertIn("fc_1", b.op.name)
-
- def testBuildModel(self):
- with tf.Graph().as_default():
- x = tf.placeholder(tf.float32, [None, 6, 6, 3])
- y = tf.placeholder(tf.int64, [None])
- layer_collection = lc.LayerCollection()
- loss, accuracy = convnet.build_model(
- x, y, num_labels=5, layer_collection=layer_collection)
-
- # Ensure layers and logits were registered.
- self.assertEqual(len(layer_collection.fisher_blocks), 3)
- self.assertEqual(len(layer_collection.losses), 1)
-
- # Ensure inference doesn't crash.
- with self.test_session() as sess:
- sess.run(tf.global_variables_initializer())
- feed_dict = {
- x: np.random.randn(10, 6, 6, 3).astype(np.float32),
- y: np.random.randint(5, size=10).astype(np.int64),
- }
- sess.run([loss, accuracy], feed_dict=feed_dict)
-
- def _build_toy_problem(self):
- """Construct a toy linear regression problem.
-
- Initial loss should be,
- 2.5 = 0.5 * (1^2 + 2^2)
-
- Returns:
- loss: 0-D Tensor representing loss to be minimized.
- accuracy: 0-D Tensors representing model accuracy.
- layer_collection: LayerCollection instance describing model architecture.
- """
- x = np.asarray([[1.], [2.]]).astype(np.float32)
- y = np.asarray([1., 2.]).astype(np.float32)
- x, y = (tf.data.Dataset.from_tensor_slices((x, y))
- .repeat(100).batch(2).make_one_shot_iterator().get_next())
- w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer())
- y_hat = tf.matmul(x, w)
- loss = tf.reduce_mean(0.5 * tf.square(y_hat - y))
- accuracy = loss
-
- layer_collection = lc.LayerCollection()
- layer_collection.register_fully_connected(params=w, inputs=x, outputs=y_hat)
- layer_collection.register_normal_predictive_distribution(y_hat)
-
- return loss, accuracy, layer_collection
-
- def testMinimizeLossSingleMachine(self):
- with tf.Graph().as_default():
- loss, accuracy, layer_collection = self._build_toy_problem()
- accuracy_ = convnet.minimize_loss_single_machine(
- loss, accuracy, layer_collection, device="/cpu:0")
- self.assertLess(accuracy_, 2.0)
-
- def testMinimizeLossDistributed(self):
- with tf.Graph().as_default():
- loss, accuracy, layer_collection = self._build_toy_problem()
- accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker(
- task_id=0,
- is_chief=True,
- num_worker_tasks=1,
- num_ps_tasks=0,
- master="",
- checkpoint_dir=None,
- loss=loss,
- accuracy=accuracy,
- layer_collection=layer_collection)
- self.assertLess(accuracy_, 2.0)
-
- def testTrainMnistSingleMachine(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- #
- # Ideally, we should check that accuracy increases as the model converges,
- # but there are too few parameters for the model to effectively memorize
- # the training set the way an MLP can.
- convnet.train_mnist_single_machine(
- data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0")
-
- def testTrainMnistMultitower(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- convnet.train_mnist_multitower(
- data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
-
- def testTrainMnistDistributed(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- convnet.train_mnist_distributed_sync_replicas(
- task_id=0,
- is_chief=True,
- num_worker_tasks=1,
- num_ps_tasks=0,
- master="",
- data_dir=None,
- num_epochs=2,
- op_strategy="chief_worker",
- use_fake_data=True)
-
-
-if __name__ == "__main__":
- tf.test.main()
diff --git a/tensorflow/contrib/kfac/examples/tests/mlp_test.py b/tensorflow/contrib/kfac/examples/tests/mlp_test.py
deleted file mode 100644
index 22da6c29f1..0000000000
--- a/tensorflow/contrib/kfac/examples/tests/mlp_test.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for mlp.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mlp
-
-
-class MlpTest(tf.test.TestCase):
-
- def testFcLayer(self):
- with tf.Graph().as_default():
- pre, act, (w, b) = mlp.fc_layer(
- layer_id=1, inputs=tf.zeros([5, 3]), output_size=10)
- self.assertShapeEqual(np.zeros([5, 10]), pre)
- self.assertShapeEqual(np.zeros([5, 10]), act)
- self.assertShapeEqual(np.zeros([3, 10]), tf.convert_to_tensor(w))
- self.assertShapeEqual(np.zeros([10]), tf.convert_to_tensor(b))
- self.assertIsInstance(w, tf.Variable)
- self.assertIsInstance(b, tf.Variable)
- self.assertIn("fc_1/", w.op.name)
- self.assertIn("fc_1/", b.op.name)
-
- def testTrainMnist(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- #
- # Ideally, we should check that accuracy increases as the model converges,
- # but that takes a non-trivial amount of compute.
- mlp.train_mnist(data_dir=None, num_epochs=1, use_fake_data=True)
-
- def testTrainMnistMultitower(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- mlp.train_mnist_multitower(
- data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
-
- def testTrainMnistEstimator(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- mlp.train_mnist_estimator(data_dir=None, num_epochs=1, use_fake_data=True)
-
-
-if __name__ == "__main__":
- tf.test.main()
diff --git a/tensorflow/contrib/kfac/examples/tests/mnist_test.py b/tensorflow/contrib/kfac/examples/tests/mnist_test.py
deleted file mode 100644
index 92f8462357..0000000000
--- a/tensorflow/contrib/kfac/examples/tests/mnist_test.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for mnist.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mnist
-
-
-class MnistTest(tf.test.TestCase):
-
- def testValues(self):
- """Ensure values are in their expected range."""
- with tf.Graph().as_default():
- examples, labels = mnist.load_mnist(
- data_dir=None, num_epochs=1, batch_size=64, use_fake_data=True)
-
- with self.test_session() as sess:
- examples_, labels_ = sess.run([examples, labels])
- self.assertTrue(np.all((0 <= examples_) & (examples_ < 1)))
- self.assertTrue(np.all((0 <= labels_) & (labels_ < 10)))
-
- def testFlattenedShapes(self):
- """Ensure images are flattened into their appropriate shape."""
- with tf.Graph().as_default():
- examples, labels = mnist.load_mnist(
- data_dir=None,
- num_epochs=1,
- batch_size=64,
- flatten_images=True,
- use_fake_data=True)
-
- with self.test_session() as sess:
- examples_, labels_ = sess.run([examples, labels])
- self.assertEqual(examples_.shape, (64, 784))
- self.assertEqual(labels_.shape, (64,))
-
- def testNotFlattenedShapes(self):
- """Ensure non-flattened images are their appropriate shape."""
- with tf.Graph().as_default():
- examples, labels = mnist.load_mnist(
- data_dir=None,
- num_epochs=1,
- batch_size=64,
- flatten_images=False,
- use_fake_data=True)
-
- with self.test_session() as sess:
- examples_, labels_ = sess.run([examples, labels])
- self.assertEqual(examples_.shape, (64, 28, 28, 1))
- self.assertEqual(labels_.shape, (64,))
-
-
-if __name__ == '__main__':
- tf.test.main()
diff --git a/tensorflow/contrib/kfac/g3doc/autoencoder.png b/tensorflow/contrib/kfac/g3doc/autoencoder.png
deleted file mode 100644
index 20f93c7703..0000000000
--- a/tensorflow/contrib/kfac/g3doc/autoencoder.png
+++ /dev/null
Binary files differ
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
deleted file mode 100644
index 6e4a8d71ba..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ /dev/null
@@ -1,160 +0,0 @@
-package(default_visibility = ["//visibility:private"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_test(
- name = "estimator_test",
- srcs = ["estimator_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:fisher_estimator",
- "//tensorflow/contrib/kfac/python/ops:layer_collection",
- "//tensorflow/contrib/kfac/python/ops:utils",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "fisher_factors_test",
- srcs = ["fisher_factors_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
- "//tensorflow/contrib/kfac/python/ops:fisher_factors",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "fisher_blocks_test",
- srcs = ["fisher_blocks_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
- "//tensorflow/contrib/kfac/python/ops:layer_collection",
- "//tensorflow/contrib/kfac/python/ops:linear_operator",
- "//tensorflow/contrib/kfac/python/ops:utils",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "layer_collection_test",
- srcs = ["layer_collection_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
- "//tensorflow/contrib/kfac/python/ops:fisher_factors",
- "//tensorflow/contrib/kfac/python/ops:layer_collection",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:variable_scope",
- ],
-)
-
-py_test(
- name = "optimizer_test",
- srcs = ["optimizer_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:fisher_factors",
- "//tensorflow/contrib/kfac/python/ops:kfac_optimizer",
- "//tensorflow/contrib/kfac/python/ops:layer_collection",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:nn",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "utils_test",
- srcs = ["utils_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_windows"], # TODO: needs investigation on Windows
- deps = [
- "//tensorflow/contrib/kfac/python/ops:utils",
- "//tensorflow/contrib/tpu",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "op_queue_test",
- srcs = ["op_queue_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:op_queue",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- ],
-)
-
-py_test(
- name = "loss_functions_test",
- srcs = ["loss_functions_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:loss_functions",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:random_ops",
- "//third_party/py/numpy",
- ],
-)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
deleted file mode 100644
index 76b31a5730..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
+++ /dev/null
@@ -1,310 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.estimator."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.kfac.python.ops import estimator
-from tensorflow.contrib.kfac.python.ops import layer_collection as lc
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-from tensorflow.python.training import training_util
-
-_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"]
-
-
-class EstimatorTest(test.TestCase):
-
- def setUp(self):
- self._graph = ops.Graph()
- with self._graph.as_default():
- self.layer_collection = lc.LayerCollection()
-
- self.inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32)
- self.weights = variable_scope.get_variable(
- "w", shape=(2, 2), dtype=dtypes.float32)
- self.bias = variable_scope.get_variable(
- "b", initializer=init_ops.zeros_initializer(), shape=(2, 1))
- self.output = math_ops.matmul(self.inputs, self.weights) + self.bias
-
- # Only register the weights.
- self.layer_collection.register_fully_connected(
- params=(self.weights,), inputs=self.inputs, outputs=self.output)
-
- self.outputs = math_ops.tanh(self.output)
- self.targets = array_ops.zeros_like(self.outputs)
- self.layer_collection.register_categorical_predictive_distribution(
- logits=self.outputs, targets=self.targets)
-
- def testEstimatorInitManualRegistration(self):
- with self._graph.as_default():
- # We should be able to build an estimator for only the registered vars.
- estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection
- )
-
- # Check that we throw an error if we try to build an estimator for vars
- # that were not manually registered.
- with self.assertRaises(ValueError):
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights, self.bias],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection
- )
- est.make_vars_and_create_op_thunks()
-
- # Check that we throw an error if we don't include registered variables,
- # i.e. self.weights
- with self.assertRaises(ValueError):
- est = estimator.FisherEstimatorRoundRobin(
- variables=[],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection)
- est.make_vars_and_create_op_thunks()
-
- @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
- def testVariableWrongNumberOfUses(self, mock_uses):
- with self.assertRaises(ValueError):
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection)
- est.make_vars_and_create_op_thunks()
-
- def testInvalidEstimationMode(self):
- with self.assertRaises(ValueError):
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection,
- estimation_mode="not_a_real_mode")
- est.make_vars_and_create_op_thunks()
-
- def testGradientsModeBuild(self):
- with self._graph.as_default():
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection,
- estimation_mode="gradients")
- est.make_vars_and_create_op_thunks()
-
- def testEmpiricalModeBuild(self):
- with self._graph.as_default():
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection,
- estimation_mode="empirical")
- est.make_vars_and_create_op_thunks()
-
- def testCurvaturePropModeBuild(self):
- with self._graph.as_default():
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection,
- estimation_mode="curvature_prop")
- est.make_vars_and_create_op_thunks()
-
- def testExactModeBuild(self):
- with self._graph.as_default():
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection,
- estimation_mode="exact")
- est.make_vars_and_create_op_thunks()
-
- def test_cov_update_thunks(self):
- """Ensures covariance update ops run once per global_step."""
- with self._graph.as_default(), self.cached_session() as sess:
- fisher_estimator = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- layer_collection=self.layer_collection,
- damping=0.2,
- cov_ema_decay=0.0)
-
- # Construct an op that executes one covariance update per step.
- global_step = training_util.get_or_create_global_step()
- (cov_variable_thunks, cov_update_op_thunks, _,
- _) = fisher_estimator.create_ops_and_vars_thunks()
- for thunk in cov_variable_thunks:
- thunk()
- cov_matrices = [
- fisher_factor.get_cov()
- for fisher_factor in self.layer_collection.get_factors()
- ]
- cov_update_op = control_flow_ops.case(
- [(math_ops.equal(global_step, i), thunk)
- for i, thunk in enumerate(cov_update_op_thunks)])
- increment_global_step = global_step.assign_add(1)
-
- sess.run(variables.global_variables_initializer())
- initial_cov_values = sess.run(cov_matrices)
-
- # Ensure there's one update per covariance matrix.
- self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))
-
- # Test is no-op if only 1 covariance matrix.
- assert len(cov_matrices) > 1
-
- for i in range(len(cov_matrices)):
- # Compare new and old covariance values
- new_cov_values = sess.run(cov_matrices)
- is_cov_equal = [
- np.allclose(initial_cov_value, new_cov_value)
- for (initial_cov_value,
- new_cov_value) in zip(initial_cov_values, new_cov_values)
- ]
- num_cov_equal = sum(is_cov_equal)
-
- # Ensure exactly one covariance matrix changes per step.
- self.assertEqual(num_cov_equal, len(cov_matrices) - i)
-
- # Run all covariance update ops.
- sess.run(cov_update_op)
- sess.run(increment_global_step)
-
- def test_round_robin_placement(self):
- """Check if the ops and variables are placed on devices correctly."""
- with self._graph.as_default():
- fisher_estimator = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- layer_collection=self.layer_collection,
- damping=0.2,
- cov_ema_decay=0.0,
- cov_devices=["/cpu:{}".format(i) for i in range(2)],
- inv_devices=["/cpu:{}".format(i) for i in range(2)])
-
- # Construct an op that executes one covariance update per step.
- (cov_update_thunks,
- inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks(
- scope="test")
- cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
- inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
- self.assertEqual(cov_update_ops[0].device, "/device:CPU:0")
- self.assertEqual(cov_update_ops[1].device, "/device:CPU:1")
- self.assertEqual(inv_update_ops[0].device, "/device:CPU:0")
- self.assertEqual(inv_update_ops[1].device, "/device:CPU:1")
- cov_matrices = [
- fisher_factor.get_cov()
- for fisher_factor in self.layer_collection.get_factors()
- ]
- inv_matrices = [
- matrix
- for fisher_factor in self.layer_collection.get_factors()
- for matrix in fisher_factor._matpower_by_exp_and_damping.values()
- ]
- self.assertEqual(cov_matrices[0].device, "/device:CPU:0")
- self.assertEqual(cov_matrices[1].device, "/device:CPU:1")
- # Inverse matrices need to be explicitly placed.
- self.assertEqual(inv_matrices[0].device, "")
- self.assertEqual(inv_matrices[1].device, "")
-
- def test_inv_update_thunks(self):
- """Ensures inverse update ops run once per global_step."""
- with self._graph.as_default(), self.cached_session() as sess:
- fisher_estimator = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- layer_collection=self.layer_collection,
- damping=0.2,
- cov_ema_decay=0.0)
-
- # Construct op that updates one inverse per global step.
- global_step = training_util.get_or_create_global_step()
- (cov_variable_thunks, _, inv_variable_thunks,
- inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks()
- for thunk in cov_variable_thunks:
- thunk()
- for thunk in inv_variable_thunks:
- thunk()
- inv_matrices = [
- matrix
- for fisher_factor in self.layer_collection.get_factors()
- for matrix in fisher_factor._matpower_by_exp_and_damping.values()
- ]
- inv_update_op = control_flow_ops.case(
- [(math_ops.equal(global_step, i), thunk)
- for i, thunk in enumerate(inv_update_op_thunks)])
- increment_global_step = global_step.assign_add(1)
-
- sess.run(variables.global_variables_initializer())
- initial_inv_values = sess.run(inv_matrices)
-
- # Ensure there's one update per inverse matrix. This is true as long as
- # there's no fan-in/fan-out or parameter re-use.
- self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))
-
- # Test is no-op if only 1 invariance matrix.
- assert len(inv_matrices) > 1
-
- # Assign each covariance matrix a value other than the identity. This
- # ensures that the inverse matrices are updated to something different as
- # well.
- cov_matrices = [
- fisher_factor.get_cov()
- for fisher_factor in self.layer_collection.get_factors()
- ]
- sess.run([
- cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
- for cov_matrix in cov_matrices
- ])
-
- for i in range(len(inv_matrices)):
- # Compare new and old inverse values
- new_inv_values = sess.run(inv_matrices)
- is_inv_equal = [
- np.allclose(initial_inv_value, new_inv_value)
- for (initial_inv_value,
- new_inv_value) in zip(initial_inv_values, new_inv_values)
- ]
- num_inv_equal = sum(is_inv_equal)
-
- # Ensure exactly one inverse matrix changes per step.
- self.assertEqual(num_inv_equal, len(inv_matrices) - i)
-
- # Run all inverse update ops.
- sess.run(inv_update_op)
- sess.run(increment_global_step)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
deleted file mode 100644
index f845def507..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ /dev/null
@@ -1,1018 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.fisher_blocks."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
-from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
-from tensorflow.contrib.kfac.python.ops import layer_collection as lc
-from tensorflow.contrib.kfac.python.ops import linear_operator as lo
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.platform import test
-
-
-# We need to set these constants since the numerical values used in the tests
-# were chosen when these used to be the defaults.
-ff.set_global_constants(init_covariances_at_zero=False,
- zero_debias=False,
- init_inverses_at_zero=False)
-
-# TODO(b/78538100): As far as I can tell, all the tests that say "Make sure our
-# inverse is something other than the identity" are actually broken. They never
-# run the covariance update ops and so the inverse actually is the identity
-# (possible plus the damping term, which would still make it a multiple of the
-# identity).
-
-
-def _make_psd(dim):
- """Constructs a PSD matrix of the given dimension."""
- mat = np.ones((dim, dim), dtype=np.float32)
- mat[np.arange(dim), np.arange(dim)] = 2. + np.arange(dim)
- return array_ops.constant(mat)
-
-
-class UtilsTest(test.TestCase):
-
- def testComputePiTracenorm(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- diag = ops.convert_to_tensor([1., 2., 0., 1.])
- left_factor = lo.LinearOperatorDiag(diag)
- right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2]))
-
- # pi is the sqrt of the left trace norm divided by the right trace norm
- pi = fb.compute_pi_tracenorm(left_factor, right_factor)
-
- pi_val = sess.run(pi)
- self.assertEqual(1., pi_val)
-
-
-class FullFBTest(test.TestCase):
-
- def testFullFBInitSingleTensor(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- self.assertAllEqual(params, block.tensors_to_compute_grads())
-
- def testFullFBInitTensorTuple(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- self.assertAllEqual(params, block.tensors_to_compute_grads())
-
- def testInstantiateFactors(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- grads = (params[0]**2, math_ops.sqrt(params[1]))
- block.instantiate_factors(grads, 0.5)
-
- def testMultiplyInverseTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = (params[0]**2, math_ops.sqrt(params[1]))
- block.instantiate_factors((grads,), 0.5)
- block._factor.instantiate_cov_variables()
- block.register_inverse()
- block._factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_inverse_update_ops())
-
- vector = array_ops.ones(3,) * 2
- output = block.multiply_inverse(vector)
-
- self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
-
- def testMultiplyInverseNotTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = array_ops.constant([[1.], [2.]])
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = params**2
- block.instantiate_factors((grads,), 0.5)
- block._factor.instantiate_cov_variables()
- block.register_inverse()
- block._factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_inverse_update_ops())
-
- vector = array_ops.ones(2,) * 2
- output = block.multiply_inverse(vector)
-
- self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
-
- def testMultiplyInverseAgainstExplicit(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
- damping = 0.5
- block.instantiate_factors((grads,), damping)
- block._factor.instantiate_cov_variables()
- block.register_inverse()
- block._factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(state_ops.assign(block._factor._cov, _make_psd(3)))
- sess.run(block._factor.make_inverse_update_ops())
-
- v_flat = np.array([4., 5., 6.], dtype=np.float32)
- vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
- output = block.multiply_inverse(vector)
- output_flat = sess.run(utils.tensors_to_column(output)).ravel()
-
- full = sess.run(block.full_fisher_block())
- explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
-
- self.assertAllClose(output_flat, explicit)
-
-
-class NaiveDiagonalFBTest(test.TestCase):
-
- def testNaiveDiagonalFBInitSingleTensor(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- self.assertAllEqual(params, block.tensors_to_compute_grads())
-
- def testNaiveDiagonalFBInitTensorTuple(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- self.assertAllEqual(params, block.tensors_to_compute_grads())
-
- def testInstantiateFactors(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- grads = (params[0]**2, math_ops.sqrt(params[1]))
- block.instantiate_factors(grads, 0.5)
-
- def testMultiplyInverseTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = (params[0]**2, math_ops.sqrt(params[1]))
- block.instantiate_factors((grads,), 0.5)
- block._factor.instantiate_cov_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_inverse_update_ops())
-
- vector = array_ops.ones(3,) * 2
- output = block.multiply_inverse(vector)
-
- self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
-
- def testMultiplyInverseNotTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = array_ops.constant([[1.], [2.]])
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = params**2
- block.instantiate_factors((grads,), 0.5)
- block._factor.instantiate_cov_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_inverse_update_ops())
- vector = array_ops.ones(2,) * 2
- output = block.multiply_inverse(vector)
-
- self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
-
- def testMultiplyInverseAgainstExplicit(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = (params[0]**2, math_ops.sqrt(params[1]))
- damping = 0.5
- block.instantiate_factors((grads,), damping)
- block._factor.instantiate_cov_variables()
-
- cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1])
- sess.run(state_ops.assign(block._factor._cov, cov))
- sess.run(block._factor.make_inverse_update_ops())
-
- v_flat = np.array([4., 5., 6.], dtype=np.float32)
- vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
- output = block.multiply_inverse(vector)
- output_flat = sess.run(utils.tensors_to_column(output)).ravel()
-
- full = sess.run(block.full_fisher_block())
- explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
- self.assertAllClose(output_flat, explicit)
-
-
-class FullyConnectedDiagonalFBTest(test.TestCase):
-
- def setUp(self):
- super(FullyConnectedDiagonalFBTest, self).setUp()
-
- self.batch_size = 4
- self.input_size = 6
- self.output_size = 3
-
- self.inputs = np.random.randn(self.batch_size, self.input_size).astype(
- np.float32)
- self.outputs = np.zeros([self.batch_size, self.output_size]).astype(
- np.float32)
- self.output_grads = np.random.randn(self.batch_size,
- self.output_size).astype(np.float32)
- self.w = np.random.randn(self.input_size, self.output_size).astype(
- np.float32)
- self.b = np.random.randn(self.output_size).astype(np.float32)
-
- def fisherApprox(self, has_bias=False):
- """Fisher approximation using default inputs."""
- if has_bias:
- inputs = np.concatenate(
- [self.inputs, np.ones([self.batch_size, 1])], axis=1)
- else:
- inputs = self.inputs
- return self.buildDiagonalFisherApproximation(inputs, self.output_grads)
-
- def buildDiagonalFisherApproximation(self, inputs, output_grads):
- """Builds explicit diagonal Fisher approximation.
-
- Fisher's diagonal is (d loss / d w)'s elements squared for
- d/dw = E[outer(input, output_grad)]
-
- where the expectation is taken over examples.
-
- Args:
- inputs: np.array of shape [batch_size, input_size].
- output_grads: np.array of shape [batch_size, output_size].
-
- Returns:
- Diagonal np.array of shape [num_params, num_params] for num_params =
- input_size * output_size.
- """
- batch_size = inputs.shape[0]
- assert output_grads.shape[0] == batch_size
- input_size = inputs.shape[1]
- output_size = output_grads.shape[1]
- fisher_diag = np.zeros((input_size, output_size))
- for i in range(batch_size):
- fisher_diag += np.square(np.outer(inputs[i], output_grads[i]))
- return np.diag(fisher_diag.flatten()) / batch_size
-
- def testMultiply(self):
- result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
- [self.output_grads])
-
- # Construct Fisher-vector product.
- expected_result = self.fisherApprox().dot(self.w.flatten())
- expected_result = expected_result.reshape(
- [self.input_size, self.output_size])
-
- self.assertAllClose(expected_result, result)
-
- def testMultiplyInverse(self):
- _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
- [self.output_grads])
-
- # Construct inverse Fisher-vector product.
- expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
- expected_result = expected_result.reshape(
- [self.input_size, self.output_size])
-
- self.assertAllClose(expected_result, result)
-
- def testRegisterAdditionalTower(self):
- """Ensure 1 big tower and 2 small towers are equivalent."""
- multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
- self.w, [self.inputs], [self.outputs], [self.output_grads])
- multiply_result_small, multiply_inverse_result_small = (
- self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
- np.split(self.outputs, 2),
- np.split(self.output_grads, 2)))
-
- self.assertAllClose(multiply_result_big, multiply_result_small)
- self.assertAllClose(multiply_inverse_result_big,
- multiply_inverse_result_small)
-
- def testMultiplyHasBias(self):
- result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
- [self.outputs], [self.output_grads])
- expected_result = self.fisherApprox(True).dot(
- np.concatenate([self.w.flatten(), self.b.flatten()]))
- expected_result = expected_result.reshape(
- [self.input_size + 1, self.output_size])
- expected_result = (expected_result[:-1], expected_result[-1])
-
- self.assertEqual(len(result), 2)
- self.assertAllClose(expected_result[0], result[0])
- self.assertAllClose(expected_result[1], result[1])
-
- def runFisherBlockOps(self, params, inputs, outputs, output_grads):
- """Run Ops guaranteed by FisherBlock interface.
-
- Args:
- params: Tensor or 2-tuple of Tensors. Represents weights or weights and
- bias of this layer.
- inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
- layer.
- outputs: list of Tensors of shape [batch_size, output_size].
- Preactivations produced by layer.
- output_grads: list of Tensors of shape [batch_size, output_size].
- Gradient of loss with respect to 'outputs'.
-
- Returns:
- multiply_result: Result of FisherBlock.multiply(params)
- multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
- """
- with ops.Graph().as_default(), self.cached_session() as sess:
- inputs = as_tensors(inputs)
- outputs = as_tensors(outputs)
- output_grads = as_tensors(output_grads)
- params = as_tensors(params)
-
- block = fb.FullyConnectedDiagonalFB(
- lc.LayerCollection(), has_bias=isinstance(params, (tuple, list)))
- for (i, o) in zip(inputs, outputs):
- block.register_additional_tower(i, o)
-
- block.instantiate_factors((output_grads,), damping=0.0)
- block._factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_covariance_update_op(0.0))
- multiply_result = sess.run(block.multiply(params))
- multiply_inverse_result = sess.run(block.multiply_inverse(params))
-
- return multiply_result, multiply_inverse_result
-
-
-class EmbeddingKFACFBTest(test.TestCase):
-
- def testInstantiateFactors(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
-
- # Create a Fisher Block.
- vocab_size = 5
- block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
-
- # Add some examples.
- inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
- outputs = array_ops.constant([[0.], [1.], [2.]])
- block.register_additional_tower(inputs, outputs)
-
- # Instantiate factor's variables. Ensure it doesn't fail.
- grads = outputs**2.
- damping = array_ops.constant(0.)
- block.instantiate_factors(((grads,),), damping)
-
- def testMultiplyInverse(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
-
- # Create a Fisher Block.
- vocab_size = 5
- block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
-
- # Add some examples.
- inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
- outputs = array_ops.constant([[0.], [1.], [2.]])
- block.register_additional_tower(inputs, outputs)
-
- # Instantiate factor's variables. Ensure it doesn't fail.
- grads = outputs**2.
- damping = array_ops.constant(0.)
- block.instantiate_factors(((grads,),), damping)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Create a sparse update.
- indices = array_ops.constant([1, 3, 4])
- values = array_ops.constant([[1.], [1.], [1.]])
- sparse_vector = ops.IndexedSlices(
- values, indices, dense_shape=[vocab_size, 1])
- dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1])
-
- # Compare Fisher-vector product against explicit result.
- result = block.multiply_inverse(sparse_vector)
- expected_result = linalg_ops.matrix_solve(block.full_fisher_block(),
- dense_vector)
-
- sess.run(tf_variables.global_variables_initializer())
- self.assertAlmostEqual(
- sess.run(expected_result[1]), sess.run(result.values[0]))
- self.assertAlmostEqual(
- sess.run(expected_result[3]), sess.run(result.values[1]))
- self.assertAlmostEqual(
- sess.run(expected_result[4]), sess.run(result.values[2]))
-
-
-class FullyConnectedKFACBasicFBTest(test.TestCase):
-
- def testFullyConnectedKFACBasicFBInit(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([1., 2.])
- outputs = array_ops.constant([3., 4.])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection())
- block.register_additional_tower(inputs, outputs)
-
- self.assertAllEqual([outputs], block.tensors_to_compute_grads())
-
- def testInstantiateFactorsHasBias(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2.], [3., 4.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True)
- block.register_additional_tower(inputs, outputs)
-
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
-
- def testInstantiateFactorsNoBias(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2.], [3., 4.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_tower(inputs, outputs)
-
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
-
- def testMultiplyInverseTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
-
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- vector = (
- np.arange(2, 6).reshape(2, 2).astype(np.float32), #
- np.arange(1, 3).reshape(2, 1).astype(np.float32))
- output = block.multiply_inverse((array_ops.constant(vector[0]),
- array_ops.constant(vector[1])))
-
- output = sess.run(output)
- self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
- output[0])
- self.assertAllClose([0.343146, 0.686291], output[1])
-
- def testMultiplyInverseNotTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2.], [3., 4.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- vector = np.arange(2, 6).reshape(2, 2).astype(np.float32)
- output = block.multiply_inverse(array_ops.constant(vector))
-
- self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
- sess.run(output))
-
- def testMultiplyInverseAgainstExplicit(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- input_dim, output_dim = 3, 2
- inputs = array_ops.zeros([32, input_dim])
- outputs = array_ops.zeros([32, output_dim])
- params = array_ops.zeros([input_dim, output_dim])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- damping = 0. # This test is only valid without damping.
- block.instantiate_factors(((grads,),), damping)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
-
- sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
- sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
-
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- v_flat = np.arange(6, dtype=np.float32)
- vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
- output = block.multiply_inverse(vector)
- output_flat = sess.run(utils.tensors_to_column(output)).ravel()
-
- full = sess.run(block.full_fisher_block())
- explicit = np.dot(np.linalg.inv(full + damping * np.eye(6)), v_flat)
-
- self.assertAllClose(output_flat, explicit)
-
-
-class ConvDiagonalFBTest(test.TestCase):
-
- def setUp(self):
- super(ConvDiagonalFBTest, self).setUp()
-
- self.batch_size = 2
- self.height = 8
- self.width = 4
- self.input_channels = 6
- self.output_channels = 3
- self.kernel_size = 1
-
- self.inputs = np.random.randn(self.batch_size, self.height, self.width,
- self.input_channels).astype(np.float32)
- self.outputs = np.zeros(
- [self.batch_size, self.height, self.width,
- self.output_channels]).astype(np.float32)
- self.output_grads = np.random.randn(
- self.batch_size, self.height, self.width, self.output_channels).astype(
- np.float32)
- self.w = np.random.randn(self.kernel_size, self.kernel_size,
- self.input_channels, self.output_channels).astype(
- np.float32)
- self.b = np.random.randn(self.output_channels).astype(np.float32)
-
- def fisherApprox(self, has_bias=False):
- """Fisher approximation using default inputs."""
- if has_bias:
- inputs = np.concatenate(
- [self.inputs,
- np.ones([self.batch_size, self.height, self.width, 1])],
- axis=-1)
- else:
- inputs = self.inputs
- return self.buildDiagonalFisherApproximation(inputs, self.output_grads,
- self.kernel_size)
-
- def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size):
- r"""Builds explicit diagonal Fisher approximation.
-
- Fisher's diagonal is (d loss / d w)'s elements squared for
- d/dw = E[\sum_{loc} outer(input_{loc}, output_grad_{loc})]
-
- where the expectation is taken over examples and the sum over (x, y)
- locations upon which the convolution is applied.
-
- Args:
- inputs: np.array of shape [batch_size, height, width, input_channels].
- output_grads: np.array of shape [batch_size, height, width,
- output_channels].
- kernel_size: int. height and width of kernel.
-
- Returns:
- Diagonal np.array of shape [num_params, num_params] for num_params =
- kernel_size^2 * input_channels * output_channels.
- """
- batch_size, height, width, input_channels = inputs.shape
- assert output_grads.shape[0] == batch_size
- assert output_grads.shape[1] == height
- assert output_grads.shape[2] == width
- output_channels = output_grads.shape[3]
-
- # If kernel_size == 1, then we don't need to worry about capturing context
- # around the pixel upon which a convolution is applied. This makes testing
- # easier.
- assert kernel_size == 1, "kernel_size != 1 isn't supported."
- num_locations = height * width
- inputs = np.reshape(inputs, [batch_size, num_locations, input_channels])
- output_grads = np.reshape(output_grads,
- [batch_size, num_locations, output_channels])
-
- fisher_diag = np.zeros((input_channels, output_channels))
- for i in range(batch_size):
- # Each example's approximation is a square(sum-of-outer-products).
- example_fisher_diag = np.zeros((input_channels, output_channels))
- for j in range(num_locations):
- example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j])
- fisher_diag += np.square(example_fisher_diag)
-
- # Normalize by batch_size (not num_locations).
- return np.diag(fisher_diag.flatten()) / batch_size
-
- def testMultiply(self):
- result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
- [self.output_grads])
-
- # Construct Fisher-vector product.
- expected_result = self.fisherApprox().dot(self.w.flatten())
- expected_result = expected_result.reshape([
- self.kernel_size, self.kernel_size, self.input_channels,
- self.output_channels
- ])
-
- self.assertAllClose(expected_result, result)
-
- def testMultiplyInverse(self):
- _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
- [self.output_grads])
-
- # Construct inverse Fisher-vector product.
- expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
- expected_result = expected_result.reshape([
- self.kernel_size, self.kernel_size, self.input_channels,
- self.output_channels
- ])
-
- self.assertAllClose(expected_result, result, atol=1e-3)
-
- def testRegisterAdditionalTower(self):
- """Ensure 1 big tower and 2 small towers are equivalent."""
- multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
- self.w, [self.inputs], [self.outputs], [self.output_grads])
- multiply_result_small, multiply_inverse_result_small = (
- self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
- np.split(self.outputs, 2),
- np.split(self.output_grads, 2)))
-
- self.assertAllClose(multiply_result_big, multiply_result_small)
- self.assertAllClose(multiply_inverse_result_big,
- multiply_inverse_result_small)
-
- def testMultiplyHasBias(self):
- result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
- [self.outputs], [self.output_grads])
- # Clone 'b' along 'input_channels' dimension.
- b_filter = np.tile(
- np.reshape(self.b, [1, 1, 1, self.output_channels]),
- [self.kernel_size, self.kernel_size, 1, 1])
- params = np.concatenate([self.w, b_filter], axis=2)
- expected_result = self.fisherApprox(True).dot(params.flatten())
-
- # Extract 'b' from concatenated parameters.
- expected_result = expected_result.reshape([
- self.kernel_size, self.kernel_size, self.input_channels + 1,
- self.output_channels
- ])
- expected_result = (expected_result[:, :, 0:-1, :],
- np.reshape(expected_result[:, :, -1, :],
- [self.output_channels]))
-
- self.assertEqual(len(result), 2)
- self.assertAllClose(expected_result[0], result[0])
- self.assertAllClose(expected_result[1], result[1])
-
- def runFisherBlockOps(self, params, inputs, outputs, output_grads):
- """Run Ops guaranteed by FisherBlock interface.
-
- Args:
- params: Tensor or 2-tuple of Tensors. Represents weights or weights and
- bias of this layer.
- inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
- layer.
- outputs: list of Tensors of shape [batch_size, output_size].
- Preactivations produced by layer.
- output_grads: list of Tensors of shape [batch_size, output_size].
- Gradient of loss with respect to 'outputs'.
-
- Returns:
- multiply_result: Result of FisherBlock.multiply(params)
- multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
- """
- with ops.Graph().as_default(), self.cached_session() as sess:
- inputs = as_tensors(inputs)
- outputs = as_tensors(outputs)
- output_grads = as_tensors(output_grads)
- params = as_tensors(params)
-
- block = fb.ConvDiagonalFB(
- lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME')
- for (i, o) in zip(inputs, outputs):
- block.register_additional_tower(i, o)
-
- block.instantiate_factors((output_grads,), damping=0.0)
- block._factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_covariance_update_op(0.0))
- multiply_result = sess.run(block.multiply(params))
- multiply_inverse_result = sess.run(block.multiply_inverse(params))
-
- return multiply_result, multiply_inverse_result
-
-
-class DepthwiseConvKFCBasicFBTest(test.TestCase):
-
- def testInstantiateFactors(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = random_ops.random_normal((3, 3, 8, 2))
- inputs = random_ops.random_normal((32, 5, 5, 8))
- outputs = random_ops.random_normal((32, 5, 5, 16))
- layer_collection = lc.LayerCollection()
- block = fb.DepthwiseConvKFCBasicFB(
- layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- block.instantiate_factors(([grads],), 0.5)
-
- def testMultiplyInverse(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = random_ops.random_normal((3, 3, 8, 2))
- inputs = random_ops.random_normal((32, 5, 5, 8))
- outputs = random_ops.random_normal((32, 5, 5, 16))
- layer_collection = lc.LayerCollection()
- block = fb.DepthwiseConvKFCBasicFB(
- layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- block.instantiate_factors(([grads],), 0.5)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Ensure inverse update op doesn't crash.
- sess.run(tf_variables.global_variables_initializer())
- sess.run([
- factor.make_inverse_update_ops()
- for factor in layer_collection.get_factors()
- ])
-
- # Ensure inverse-vector multiply doesn't crash.
- output = block.multiply_inverse(params)
- sess.run(output)
-
- # Ensure same shape.
- self.assertAllEqual(output.shape, params.shape)
-
-
-class ConvKFCBasicFBTest(test.TestCase):
-
- def _testConvKFCBasicFBInitParams(self, params):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- if isinstance(params, (list, tuple)):
- params = [array_ops.constant(param) for param in params]
- else:
- params = array_ops.constant(params)
- inputs = random_ops.random_normal((2, 2, 2))
- outputs = random_ops.random_normal((2, 2, 2))
- block = fb.ConvKFCBasicFB(
- lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_tower(inputs, outputs)
-
- self.assertAllEqual([outputs], block.tensors_to_compute_grads())
-
- def testConvKFCBasicFBInitParamsParamsTuple(self):
- self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])])
-
- def testConvKFCBasicFBInitParamsParamsSingle(self):
- self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])])
-
- def testMultiplyInverseTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = random_ops.random_normal((2, 2, 2, 2))
- inputs = random_ops.random_normal((2, 2, 2, 2))
- outputs = random_ops.random_normal((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(
- lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32),
- np.arange(2, 4).reshape(2, 1).astype(np.float32))
- output = block.multiply_inverse((array_ops.constant(vector[0]),
- array_ops.constant(vector[1])))
-
- output = sess.run(output)
- self.assertAllClose([0.136455, 0.27291], output[0][0])
- self.assertAllClose([0.27291, 0.409365], output[1])
-
- def testMultiplyInverseNotTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = random_ops.random_normal((2, 2, 2, 2))
- inputs = random_ops.random_normal((2, 2, 2, 2))
- outputs = random_ops.random_normal((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(
- lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_tower(inputs, outputs)
- self.assertFalse(block._has_bias)
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- vector = np.arange(1, 17).reshape(8, 2).astype(np.float32)
- output = block.multiply_inverse(array_ops.constant(vector))
-
- self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])
-
- def testMultiplyInverseNotTupleWithBias(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = [random_ops.random_normal((2, 2, 2, 2))]
- inputs = random_ops.random_normal((2, 2, 2, 2))
- outputs = random_ops.random_normal((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(
- lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_tower(inputs, outputs)
- self.assertTrue(block._has_bias)
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- vector = np.arange(1, 19).reshape(9, 2).astype(np.float32)
- output = block.multiply_inverse(array_ops.constant(vector))
-
- self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])
-
- def testMultiplyInverseAgainstExplicit(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = array_ops.zeros((2, 2, 2, 2))
- inputs = array_ops.zeros((2, 2, 2, 2))
- outputs = array_ops.zeros((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(
- lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- damping = 0. # This test is only valid without damping.
- block.instantiate_factors(((grads,),), damping)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
- sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- v_flat = np.arange(16, dtype=np.float32)
- vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
- output = block.multiply_inverse(vector)
- output_flat = sess.run(utils.tensors_to_column(output)).ravel()
-
- full = sess.run(block.full_fisher_block())
- explicit = np.dot(np.linalg.inv(full + damping * np.eye(16)), v_flat)
-
- self.assertAllClose(output_flat, explicit)
-
-
-class FullyConnectedSeriesFBTest(test.TestCase):
-
- def testFullyConnectedSeriesFBInit(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([1., 2.])
- outputs = array_ops.constant([3., 4.])
- block = fb.FullyConnectedSeriesFB(lc.LayerCollection())
- block.register_additional_tower([inputs], [outputs])
- self.assertAllEqual([[outputs]], block.tensors_to_compute_grads())
-
- def testInstantiateFactorsHasBias(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2.], [3., 4.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedSeriesFB(
- lc.LayerCollection(),
- has_bias=True)
- block.register_additional_tower([inputs], [outputs])
- grads = outputs**2
- block.instantiate_factors((((grads,),),), 0.5)
-
- def testInstantiateFactorsNoBias(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2.], [3., 4.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedSeriesFB(
- lc.LayerCollection(),
- has_bias=False)
- block.register_additional_tower([inputs], [outputs])
- grads = outputs**2
- block.instantiate_factors((((grads,),),), 0.5)
-
-
-def as_tensors(tensor_or_tuple):
- """Converts a potentially nested tuple of np.array to Tensors."""
- if isinstance(tensor_or_tuple, (tuple, list)):
- return tuple(as_tensors(t) for t in tensor_or_tuple)
- return ops.convert_to_tensor(tensor_or_tuple)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
deleted file mode 100644
index a396ca3f85..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ /dev/null
@@ -1,955 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.fisher_factors."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import numpy.random as npr
-
-from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
-from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.platform import test
-
-
-# We need to set these constants since the numerical values used in the tests
-# were chosen when these used to be the defaults.
-ff.set_global_constants(init_covariances_at_zero=False,
- zero_debias=False,
- init_inverses_at_zero=False)
-
-
-def make_damping_func(damping):
- return fb._package_func(lambda: damping, damping)
-
-
-class FisherFactorTestingDummy(ff.FisherFactor):
- """Dummy class to test the non-abstract methods on ff.FisherFactor."""
-
- @property
- def _var_scope(self):
- return 'dummy/a_b_c'
-
- @property
- def _cov_shape(self):
- raise NotImplementedError
-
- @property
- def _num_sources(self):
- return 1
-
- @property
- def _dtype(self):
- return dtypes.float32
-
- def _compute_new_cov(self):
- raise NotImplementedError
-
- def instantiate_covariance(self):
- pass
-
- def make_inverse_update_ops(self):
- return []
-
- def get_cov(self):
- return NotImplementedError
-
- def instantiate_inv_variables(self):
- return NotImplementedError
-
- def _num_towers(self):
- raise NotImplementedError
-
- def _get_data_device(self):
- raise NotImplementedError
-
- def register_matpower(self, exp, damping_func):
- raise NotImplementedError
-
- def register_cholesky(self, damping_func):
- raise NotImplementedError
-
- def register_cholesky_inverse(self, damping_func):
- raise NotImplementedError
-
- def get_matpower(self, exp, damping_func):
- raise NotImplementedError
-
- def get_cholesky(self, damping_func):
- raise NotImplementedError
-
- def get_cholesky_inverse(self, damping_func):
- raise NotImplementedError
-
- def get_cov_as_linear_operator(self):
- raise NotImplementedError
-
-
-class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor):
- """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor.
- """
-
- def __init__(self, shape):
- self._shape = shape
- super(DenseSquareMatrixFactorTestingDummy, self).__init__()
-
- @property
- def _var_scope(self):
- return 'dummy/a_b_c'
-
- @property
- def _cov_shape(self):
- return self._shape
-
- @property
- def _num_sources(self):
- return 1
-
- @property
- def _dtype(self):
- return dtypes.float32
-
- def _compute_new_cov(self):
- raise NotImplementedError
-
- def instantiate_covariance(self):
- pass
-
- def _num_towers(self):
- raise NotImplementedError
-
- def _get_data_device(self):
- raise NotImplementedError
-
-
-class NumericalUtilsTest(test.TestCase):
-
- def testComputeCovAgainstNumpy(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- npr.seed(0)
- random_seed.set_random_seed(200)
-
- x = npr.randn(100, 3)
- cov = ff.compute_cov(array_ops.constant(x))
- np_cov = np.dot(x.T, x) / x.shape[0]
-
- self.assertAllClose(sess.run(cov), np_cov)
-
- def testComputeCovAgainstNumpyWithAlternativeNormalizer(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- npr.seed(0)
- random_seed.set_random_seed(200)
-
- normalizer = 10.
- x = npr.randn(100, 3)
- cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer)
- np_cov = np.dot(x.T, x) / normalizer
-
- self.assertAllClose(sess.run(cov), np_cov)
-
- def testAppendHomog(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- npr.seed(0)
-
- m, n = 3, 4
- a = npr.randn(m, n)
- a_homog = ff.append_homog(array_ops.constant(a))
- np_result = np.hstack([a, np.ones((m, 1))])
-
- self.assertAllClose(sess.run(a_homog), np_result)
-
-
-class NameStringUtilFunctionTest(test.TestCase):
-
- def _make_tensor(self):
- x = array_ops.placeholder(dtypes.float64, (3, 1))
- w = array_ops.constant(npr.RandomState(0).randn(3, 3))
- y = math_ops.matmul(w, x)
- g = gradients_impl.gradients(y, x)[0]
- return g
-
- def testScopeStringFromParamsSingleTensor(self):
- with tf_ops.Graph().as_default():
- g = self._make_tensor()
- scope_string = ff.scope_string_from_params(g)
- self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string)
-
- def testScopeStringFromParamsMultipleTensors(self):
- with tf_ops.Graph().as_default():
- x = array_ops.constant(1,)
- y = array_ops.constant(2,)
- scope_string = ff.scope_string_from_params((x, y))
- self.assertEqual('Const_Const_1', scope_string)
-
- def testScopeStringFromParamsMultipleTypes(self):
- with tf_ops.Graph().as_default():
- x = array_ops.constant(1,)
- y = array_ops.constant(2,)
- scope_string = ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4,
- (x, y)])
- self.assertEqual('1-2-3_foo_True_4_Const__Const_1', scope_string)
-
- def testScopeStringFromParamsUnsupportedType(self):
- with tf_ops.Graph().as_default():
- x = array_ops.constant(1,)
- y = array_ops.constant(2,)
- unsupported = 1.2 # Floats are not supported.
- with self.assertRaises(ValueError):
- ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, (x, y),
- unsupported])
-
- def testScopeStringFromName(self):
- with tf_ops.Graph().as_default():
- g = self._make_tensor()
- scope_string = ff.scope_string_from_name(g)
- self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string)
-
- def testScalarOrTensorToString(self):
- with tf_ops.Graph().as_default():
- self.assertEqual(ff.scalar_or_tensor_to_string(5.), repr(5.))
-
- g = self._make_tensor()
- scope_string = ff.scope_string_from_name(g)
- self.assertEqual(ff.scalar_or_tensor_to_string(g), scope_string)
-
-
-class FisherFactorTest(test.TestCase):
-
- def testMakeInverseUpdateOps(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- factor = FisherFactorTestingDummy()
-
- self.assertEqual(0, len(factor.make_inverse_update_ops()))
-
-
-class DenseSquareMatrixFactorTest(test.TestCase):
-
- def testRegisterDampedInverse(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- shape = [2, 2]
- factor = DenseSquareMatrixFactorTestingDummy(shape)
- factor_var_scope = 'dummy/a_b_c'
-
- damping_funcs = [make_damping_func(0.1),
- make_damping_func(0.1),
- make_damping_func(1e-5),
- make_damping_func(1e-5)]
- for damping_func in damping_funcs:
- factor.register_inverse(damping_func)
-
- factor.instantiate_inv_variables()
-
- inv = factor.get_inverse(damping_funcs[0]).to_dense()
- self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense())
- self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense())
- self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(),
- factor.get_inverse(damping_funcs[3]).to_dense())
- factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
- factor_var_scope)
- factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
-
- self.assertEqual(set([inv,
- factor.get_inverse(damping_funcs[2]).to_dense()]),
- set(factor_tensors))
- self.assertEqual(shape, inv.get_shape())
-
- def testRegisterMatpower(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- shape = [3, 3]
- factor = DenseSquareMatrixFactorTestingDummy(shape)
- factor_var_scope = 'dummy/a_b_c'
-
- # TODO(b/74201126): Change to using the same func for both once
- # Topohash is in place.
- damping_func_1 = make_damping_func(0.5)
- damping_func_2 = make_damping_func(0.5)
-
- factor.register_matpower(-0.5, damping_func_1)
- factor.register_matpower(2, damping_func_2)
-
- factor.instantiate_inv_variables()
-
- factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
- factor_var_scope)
-
- factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
-
- matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense()
- matpower2 = factor.get_matpower(2, damping_func_2).to_dense()
-
- self.assertEqual(set([matpower1, matpower2]), set(factor_tensors))
-
- self.assertEqual(shape, matpower1.get_shape())
- self.assertEqual(shape, matpower2.get_shape())
-
- def testMakeInverseUpdateOps(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- factor = FisherFactorTestingDummy()
-
- self.assertEqual(0, len(factor.make_inverse_update_ops()))
-
- def testMakeInverseUpdateOpsManyInversesEigenDecomp(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- cov = np.array([[1., 2.], [3., 4.]])
- factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
- factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
-
- damping_funcs = []
- for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1):
- damping_funcs.append(make_damping_func(1./i))
-
- for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
- factor.register_inverse(damping_funcs[i])
-
- factor.instantiate_inv_variables()
- ops = factor.make_inverse_update_ops()
- self.assertEqual(1, len(ops))
-
- sess.run(tf_variables.global_variables_initializer())
- new_invs = []
- sess.run(ops)
- for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
- # The inverse op will assign the damped inverse of cov to the inv var.
- new_invs.append(
- sess.run(factor.get_inverse(damping_funcs[i]).to_dense()))
-
- # We want to see that the new invs are all different from each other.
- for i in range(len(new_invs)):
- for j in range(i + 1, len(new_invs)):
- # Just check the first element.
- self.assertNotEqual(new_invs[i][0][0], new_invs[j][0][0])
-
- def testMakeInverseUpdateOpsMatPowerEigenDecomp(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- cov = np.array([[6., 2.], [2., 4.]])
- factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
- factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
- exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power
- damping = 0.5
- damping_func = make_damping_func(damping)
-
- factor.register_matpower(exp, damping_func)
- factor.instantiate_inv_variables()
- ops = factor.make_inverse_update_ops()
- self.assertEqual(1, len(ops))
-
- sess.run(tf_variables.global_variables_initializer())
- sess.run(ops[0])
- matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense())
- matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
- self.assertAllClose(matpower, matpower_np)
-
- def testMakeInverseUpdateOpsNoEigenDecomp(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric
- factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
- factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
-
- damping_func = make_damping_func(0)
-
- factor.register_inverse(damping_func)
- factor.instantiate_inv_variables()
- ops = factor.make_inverse_update_ops()
- self.assertEqual(1, len(ops))
-
- sess.run(tf_variables.global_variables_initializer())
- # The inverse op will assign the damped inverse of cov to the inv var.
- old_inv = sess.run(factor.get_inverse(damping_func).to_dense())
- self.assertAllClose(
- sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv)
-
- sess.run(ops)
- new_inv = sess.run(factor.get_inverse(damping_func).to_dense())
- self.assertAllClose(new_inv, np.linalg.inv(cov))
-
-
-class FullFactorTest(test.TestCase):
-
- def testFullFactorInit(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), name='a/b/c')
- factor = ff.FullFactor((tensor,), 32)
- factor.instantiate_cov_variables()
- self.assertEqual([6, 6], factor.get_cov().get_shape().as_list())
-
- def testFullFactorInitFloat64(self):
- with tf_ops.Graph().as_default():
- dtype = dtypes.float64_ref
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
- factor = ff.FullFactor((tensor,), 32)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual([6, 6], cov.get_shape().as_list())
-
- def testMakeCovarianceUpdateOp(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([1., 2.], name='a/b/c')
- factor = ff.FullFactor((tensor,), 2)
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[0.75, 0.5], [0.5, 1.5]], new_cov)
-
-
-class NaiveDiagonalFactorTest(test.TestCase):
-
- def testNaiveDiagonalFactorInit(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), name='a/b/c')
- factor = ff.NaiveDiagonalFactor((tensor,), 32)
- factor.instantiate_cov_variables()
- self.assertEqual([6, 1], factor.get_cov().get_shape().as_list())
-
- def testNaiveDiagonalFactorInitFloat64(self):
- with tf_ops.Graph().as_default():
- dtype = dtypes.float64_ref
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
- factor = ff.NaiveDiagonalFactor((tensor,), 32)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual([6, 1], cov.get_shape().as_list())
-
- def testMakeCovarianceUpdateOp(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([1., 2.], name='a/b/c')
- factor = ff.NaiveDiagonalFactor((tensor,), 2)
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[0.75], [1.5]], new_cov)
-
-
-class EmbeddingInputKroneckerFactorTest(test.TestCase):
-
- def testInitialization(self):
- with tf_ops.Graph().as_default():
- input_ids = array_ops.constant([[0], [1], [4]])
- vocab_size = 5
- factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.shape.as_list(), [vocab_size])
-
- def testCovarianceUpdateOp(self):
- with tf_ops.Graph().as_default():
- input_ids = array_ops.constant([[0], [1], [4]])
- vocab_size = 5
- factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
- factor.instantiate_cov_variables()
- cov_update_op = factor.make_covariance_update_op(0.0)
-
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(cov_update_op)
- self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov)
-
-
-class ConvDiagonalFactorTest(test.TestCase):
-
- def setUp(self):
- self.batch_size = 10
- self.height = self.width = 32
- self.in_channels = 3
- self.out_channels = 1
- self.kernel_height = self.kernel_width = 3
- self.strides = [1, 2, 2, 1]
- self.data_format = 'NHWC'
- self.padding = 'SAME'
- self.kernel_shape = [
- self.kernel_height, self.kernel_width, self.in_channels,
- self.out_channels
- ]
-
- def testInit(self):
- with tf_ops.Graph().as_default():
- inputs = random_ops.random_uniform(
- [self.batch_size, self.height, self.width, self.in_channels])
- outputs_grads = [
- random_ops.random_uniform([
- self.batch_size, self.height // self.strides[1],
- self.width // self.strides[2], self.out_channels
- ]) for _ in range(3)
- ]
-
- factor = ff.ConvDiagonalFactor(
- (inputs,),
- (outputs_grads,),
- self.kernel_shape,
- self.strides,
- self.padding,
- data_format=self.data_format)
- factor.instantiate_cov_variables()
-
- # Ensure covariance matrix's shape makes sense.
- self.assertEqual([
- self.kernel_height * self.kernel_width * self.in_channels,
- self.out_channels
- ],
- factor.get_cov().shape.as_list())
-
- def testMakeCovarianceUpdateOp(self):
- with tf_ops.Graph().as_default():
- # Construct all arguments such that convolution kernel is applied in
- # exactly one spatial location.
- inputs = np.random.randn(
- 1, # batch_size
- self.kernel_height,
- self.kernel_width,
- self.in_channels) # in_channels
- outputs_grad = np.random.randn(
- 1, # batch_size
- 1, # output_height
- 1, # output_width
- self.out_channels)
-
- factor = ff.ConvDiagonalFactor(
- (constant_op.constant(inputs),),
- ((constant_op.constant(outputs_grad),),),
- self.kernel_shape,
- strides=[1, 1, 1, 1],
- padding='VALID')
- factor.instantiate_cov_variables()
-
- # Completely forget initial value on first update.
- cov_update_op = factor.make_covariance_update_op(0.0)
-
- # Ensure new covariance value is same as outer-product of inputs/outputs
- # vectorized, squared.
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- cov = sess.run(cov_update_op)
- expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2
- self.assertAllClose(expected_cov, cov)
-
- def testHasBias(self):
- with tf_ops.Graph().as_default():
- inputs = random_ops.random_uniform(
- [self.batch_size, self.height, self.width, self.in_channels])
- outputs_grads = [
- random_ops.random_uniform([
- self.batch_size, self.height // self.strides[1],
- self.width // self.strides[2], self.out_channels
- ]) for _ in range(3)
- ]
-
- factor = ff.ConvDiagonalFactor(
- (inputs,),
- (outputs_grads,),
- self.kernel_shape,
- self.strides,
- self.padding,
- data_format=self.data_format,
- has_bias=True)
- factor.instantiate_cov_variables()
-
- # Ensure shape accounts for bias.
- self.assertEqual([
- self.kernel_height * self.kernel_width * self.in_channels + 1,
- self.out_channels
- ],
- factor.get_cov().shape.as_list())
-
- # Ensure update op doesn't crash.
- cov_update_op = factor.make_covariance_update_op(0.0)
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(cov_update_op)
-
-
-class FullyConnectedKroneckerFactorTest(test.TestCase):
-
- def _testFullyConnectedKroneckerFactorInit(self,
- has_bias,
- final_shape,
- dtype=dtypes.float32_ref):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
- factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual(final_shape, cov.get_shape().as_list())
-
- def testFullyConnectedKroneckerFactorInitNoBias(self):
- for dtype in (dtypes.float32_ref, dtypes.float64_ref):
- self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype)
-
- def testFullyConnectedKroneckerFactorInitWithBias(self):
- for dtype in (dtypes.float32_ref, dtypes.float64_ref):
- self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype)
-
- def testMakeCovarianceUpdateOpWithBias(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True)
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov)
-
- def testMakeCovarianceUpdateOpNoBias(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedKroneckerFactor(((tensor,),))
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
-
-
-class ConvFactorTestCase(test.TestCase):
-
- def assertMatrixRank(self, rank, matrix, atol=1e-5):
- assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.'
- eigvals = np.linalg.eigvals(matrix)
- nnz_eigvals = np.sum(eigvals > atol)
- self.assertEqual(
- rank,
- nnz_eigvals,
- msg=('Found %d of %d expected non-zero eigenvalues: %s.' %
- (nnz_eigvals, rank, eigvals)))
-
-
-class ConvInputKroneckerFactorTest(ConvFactorTestCase):
-
- def test3DConvolution(self):
- with tf_ops.Graph().as_default():
- batch_size = 1
- width = 3
- in_channels = 3**3
- out_channels = 4
-
- factor = ff.ConvInputKroneckerFactor(
- inputs=(random_ops.random_uniform(
- (batch_size, width, width, width, in_channels), seed=0),),
- filter_shape=(width, width, width, in_channels, out_channels),
- padding='SAME',
- strides=(2, 2, 2),
- extract_patches_fn='extract_convolution_patches',
- has_bias=False)
- factor.instantiate_cov_variables()
-
- # Ensure shape of covariance matches input size of filter.
- input_size = in_channels * (width**3)
- self.assertEqual([input_size, input_size],
- factor.get_cov().shape.as_list())
-
- # Ensure cov_update_op doesn't crash.
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov())
-
- # Cov should be rank-8, as the filter will be applied at each corner of
- # the 4-D cube.
- self.assertMatrixRank(8, cov)
-
- def testPointwiseConv2d(self):
- with tf_ops.Graph().as_default():
- batch_size = 1
- width = 3
- in_channels = 3**2
- out_channels = 4
-
- factor = ff.ConvInputKroneckerFactor(
- inputs=(random_ops.random_uniform(
- (batch_size, width, width, in_channels), seed=0),),
- filter_shape=(1, 1, in_channels, out_channels),
- padding='SAME',
- strides=(1, 1, 1, 1),
- extract_patches_fn='extract_pointwise_conv2d_patches',
- has_bias=False)
- factor.instantiate_cov_variables()
-
- # Ensure shape of covariance matches input size of filter.
- self.assertEqual([in_channels, in_channels],
- factor.get_cov().shape.as_list())
-
- # Ensure cov_update_op doesn't crash.
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov())
-
- # Cov should be rank-9, as the filter will be applied at each location.
- self.assertMatrixRank(9, cov)
-
- def testStrides(self):
- with tf_ops.Graph().as_default():
- batch_size = 1
- width = 3
- in_channels = 3**2
- out_channels = 4
-
- factor = ff.ConvInputKroneckerFactor(
- inputs=(random_ops.random_uniform(
- (batch_size, width, width, in_channels), seed=0),),
- filter_shape=(1, 1, in_channels, out_channels),
- padding='SAME',
- strides=(1, 2, 1, 1),
- extract_patches_fn='extract_image_patches',
- has_bias=False)
- factor.instantiate_cov_variables()
-
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov())
-
- # Cov should be the sum of 3 * 2 = 6 outer products.
- self.assertMatrixRank(6, cov)
-
- def testDilationRate(self):
- with tf_ops.Graph().as_default():
- batch_size = 1
- width = 3
- in_channels = 2
- out_channels = 4
-
- factor = ff.ConvInputKroneckerFactor(
- inputs=(random_ops.random_uniform(
- (batch_size, width, width, in_channels), seed=0),),
- filter_shape=(3, 3, in_channels, out_channels),
- padding='SAME',
- extract_patches_fn='extract_image_patches',
- strides=(1, 1, 1, 1),
- dilation_rate=(1, width, width, 1),
- has_bias=False)
- factor.instantiate_cov_variables()
-
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov())
-
- # Cov should be rank = in_channels, as only the center of the filter
- # receives non-zero input for each input channel.
- self.assertMatrixRank(in_channels, cov)
-
- def testConvInputKroneckerFactorInitNoBias(self):
- with tf_ops.Graph().as_default():
- tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
- factor = ff.ConvInputKroneckerFactor(
- inputs=(tensor,),
- filter_shape=(1, 2, 3, 4),
- padding='SAME',
- has_bias=False)
- factor.instantiate_cov_variables()
- self.assertEqual([1 * 2 * 3, 1 * 2 * 3],
- factor.get_cov().get_shape().as_list())
-
- def testConvInputKroneckerFactorInit(self):
- with tf_ops.Graph().as_default():
- tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
- factor = ff.ConvInputKroneckerFactor(
- (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
- factor.instantiate_cov_variables()
- self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
- factor.get_cov().get_shape().as_list())
-
- def testConvInputKroneckerFactorInitFloat64(self):
- with tf_ops.Graph().as_default():
- dtype = dtypes.float64_ref
- tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64)
- factor = ff.ConvInputKroneckerFactor(
- (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
- cov.get_shape().as_list())
-
- def testMakeCovarianceUpdateOpWithBias(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- input_shape = (2, 1, 1, 1)
- tensor = array_ops.constant(
- np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
- np.float32))
- factor = ff.ConvInputKroneckerFactor(
- (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True)
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(0.))
- self.assertAllClose(
- [
- [(1. + 4.) / 2., (1. + 2.) / 2.], #
- [(1. + 2.) / 2., (1. + 1.) / 2.]
- ], #
- new_cov)
-
- def testMakeCovarianceUpdateOpNoBias(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- input_shape = (2, 1, 1, 1)
- tensor = array_ops.constant(
- np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
- np.float32))
- factor = ff.ConvInputKroneckerFactor(
- (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME')
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(0.))
- self.assertAllClose([[(1. + 4.) / 2.]], new_cov)
-
- def testSubSample(self):
- with tf_ops.Graph().as_default():
- patches_1 = array_ops.constant(1, shape=(10, 2))
- patches_2 = array_ops.constant(1, shape=(10, 8))
- patches_3 = array_ops.constant(1, shape=(3, 3))
- patches_1_sub = ff._subsample_for_cov_computation(patches_1)
- patches_2_sub = ff._subsample_for_cov_computation(patches_2)
- patches_3_sub = ff._subsample_for_cov_computation(patches_3)
- patches_1_sub_batch_size = patches_1_sub.shape.as_list()[0]
- patches_2_sub_batch_size = patches_2_sub.shape.as_list()[0]
- patches_3_sub_batch_size = patches_3_sub.shape.as_list()[0]
- self.assertEqual(2, patches_1_sub_batch_size)
- self.assertEqual(8, patches_2_sub_batch_size)
- self.assertEqual(3, patches_3_sub_batch_size)
-
-
-class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
-
- def test3DConvolution(self):
- with tf_ops.Graph().as_default():
- batch_size = 1
- width = 3
- out_channels = width**3
-
- factor = ff.ConvOutputKroneckerFactor(outputs_grads=([
- random_ops.random_uniform(
- (batch_size, width, width, width, out_channels), seed=0)
- ],))
- factor.instantiate_cov_variables()
-
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov())
-
- # Cov should be rank 3^3, as each spatial position donates a rank-1
- # update.
- self.assertMatrixRank(width**3, cov)
-
- def testConvOutputKroneckerFactorInit(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c')
- factor = ff.ConvOutputKroneckerFactor(((tensor,),))
- factor.instantiate_cov_variables()
- self.assertEqual([5, 5], factor.get_cov().get_shape().as_list())
-
- def testConvOutputKroneckerFactorInitFloat64(self):
- with tf_ops.Graph().as_default():
- dtype = dtypes.float64_ref
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c')
- factor = ff.ConvOutputKroneckerFactor(((tensor,),))
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual([5, 5], cov.get_shape().as_list())
-
- def testMakeCovarianceUpdateOp(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32)
- factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),))
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov)
-
-
-class FullyConnectedMultiKFTest(test.TestCase):
-
- def testFullyConnectedMultiKFInit(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), name='a/b/c')
- factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
- factor.instantiate_cov_variables()
- self.assertEqual([3, 3], factor.get_cov().get_shape().as_list())
-
- def testFullyConnectedMultiKFInitFloat64(self):
- with tf_ops.Graph().as_default():
- dtype = dtypes.float64_ref
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
- factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual([3, 3], cov.get_shape().as_list())
-
- def testMakeCovarianceUpdateOpWithBias(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True)
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov)
-
- def testMakeCovarianceUpdateOpNoBias(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedMultiKF(((tensor,),))
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
deleted file mode 100644
index 586fcd4c3c..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ /dev/null
@@ -1,597 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.layer_collection."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.kfac.python.ops import fisher_blocks
-from tensorflow.contrib.kfac.python.ops import fisher_factors
-from tensorflow.contrib.kfac.python.ops import layer_collection
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import test
-
-
-class MockFisherBlock(object):
- """A fake FisherBlock."""
-
- num_registered_towers = 2
-
- def __init__(self, name='MockFisherBlock'):
- self.name = name
-
- def __eq__(self, other):
- return isinstance(other, MockFisherBlock) and other.name == self.name
-
- def __hash__(self):
- return hash(self.name)
-
-
-class LayerParametersDictTest(test.TestCase):
-
- def testSetItem(self):
- """Ensure insertion, contains, retrieval works for supported key types."""
- with ops.Graph().as_default():
- lp_dict = layer_collection.LayerParametersDict()
-
- x = array_ops.constant(0)
- y0 = array_ops.constant(0)
- y1 = array_ops.constant(0)
- z0 = array_ops.constant(0)
- z1 = array_ops.constant(0)
- keys = [x, (y0, y1), [z0, z1]]
- for key in keys:
- lp_dict[key] = key
-
- for key in keys:
- self.assertTrue(key in lp_dict)
- self.assertEqual(lp_dict[key], key)
-
- def testSetItemOverlap(self):
- """Ensure insertion fails if key overlaps with existing key."""
- with ops.Graph().as_default():
- lp_dict = layer_collection.LayerParametersDict()
-
- x = array_ops.constant(0)
- y = array_ops.constant(0)
- lp_dict[x] = 'value'
-
- with self.assertRaises(ValueError):
- lp_dict[(x, y)] = 'value'
-
- # Ensure 'y' wasn't inserted.
- self.assertTrue(x in lp_dict)
- self.assertFalse(y in lp_dict)
-
-
-class LayerCollectionTest(test.TestCase):
-
- def testLayerCollectionInit(self):
- lc = layer_collection.LayerCollection()
- self.assertEqual(0, len(lc.get_blocks()))
- self.assertEqual(0, len(lc.get_factors()))
- self.assertFalse(lc.losses)
-
- def testRegisterBlocks(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- lc = layer_collection.LayerCollection()
- lc.register_fully_connected(
- array_ops.constant(1), array_ops.constant(2), array_ops.constant(3))
- lc.register_fully_connected(
- array_ops.constant(1),
- array_ops.constant(2),
- array_ops.constant(3),
- approx=layer_collection.APPROX_DIAGONAL_NAME)
- lc.register_conv2d(
- params=array_ops.ones((2, 3, 4, 5)),
- strides=[1, 1, 1, 1],
- padding='SAME',
- inputs=array_ops.ones((1, 2, 3, 4)),
- outputs=array_ops.ones((1, 1, 1, 5)))
- lc.register_conv2d(
- params=array_ops.ones((2, 3, 4, 5)),
- strides=[1, 1, 1, 1],
- padding='SAME',
- inputs=array_ops.ones((1, 2, 3, 4)),
- outputs=array_ops.ones((1, 1, 1, 5)),
- approx=layer_collection.APPROX_DIAGONAL_NAME)
- lc.register_separable_conv2d(
- depthwise_params=array_ops.ones((3, 3, 1, 2)),
- pointwise_params=array_ops.ones((1, 1, 2, 4)),
- inputs=array_ops.ones((32, 5, 5, 1)),
- depthwise_outputs=array_ops.ones((32, 5, 5, 2)),
- pointwise_outputs=array_ops.ones((32, 5, 5, 4)),
- strides=[1, 1, 1, 1],
- padding='SAME')
- lc.register_convolution(
- params=array_ops.ones((3, 3, 1, 8)),
- inputs=array_ops.ones((32, 5, 5, 1)),
- outputs=array_ops.ones((32, 5, 5, 8)),
- padding='SAME')
- lc.register_generic(
- array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
- lc.register_generic(
- array_ops.constant(6),
- 16,
- approx=layer_collection.APPROX_DIAGONAL_NAME)
- lc.register_fully_connected_multi(
- array_ops.constant(1),
- (array_ops.constant(2), array_ops.constant(3)),
- (array_ops.constant(4), array_ops.constant(5)))
- lc.register_conv2d_multi(
- params=array_ops.ones((2, 3, 4, 5)),
- strides=[1, 1, 1, 1],
- padding='SAME',
- inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))),
- outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10))))
- lc.register_embedding_multi(
- array_ops.constant((1,)),
- (array_ops.constant(2), array_ops.constant(3)),
- (array_ops.constant(4), array_ops.constant(5)))
-
- self.assertEqual(12, len(lc.get_blocks()))
-
- def testRegisterBlocksMultipleRegistrations(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- lc = layer_collection.LayerCollection()
- key = array_ops.constant(1)
- lc.register_fully_connected(key, array_ops.constant(2),
- array_ops.constant(3))
- with self.assertRaises(ValueError) as cm:
- lc.register_generic(key, 16)
- self.assertIn('already in LayerCollection', str(cm.exception))
-
- def testRegisterSingleParamNotRegistered(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {
- variable_scope.get_variable('y', initializer=array_ops.constant(1,)):
- '1'
- }
- lc.register_block(x, 'foo')
-
- def testShouldRegisterSingleParamRegistered(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {x: '1'}
- with self.assertRaises(ValueError) as cm:
- lc.register_block(x, 'foo')
- self.assertIn('already in LayerCollection', str(cm.exception))
-
- def testRegisterSingleParamRegisteredInTuple(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {(x, y): '1'}
- with self.assertRaises(ValueError) as cm:
- lc.register_block(x, 'foo')
- self.assertIn('was already registered', str(cm.exception))
-
- def testRegisterTupleParamNotRegistered(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {
- variable_scope.get_variable('z', initializer=array_ops.constant(1,)):
- '1'
- }
-
- lc.register_block((x, y), 'foo')
- self.assertEqual(set(['1', 'foo']), set(lc.get_blocks()))
-
- def testRegisterTupleParamRegistered(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {(x, y): '1'}
-
- with self.assertRaises(ValueError) as cm:
- lc.register_block((x, y), 'foo')
- self.assertIn('already in LayerCollection', str(cm.exception))
-
- def testRegisterTupleParamRegisteredInSuperset(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {(x, y, z): '1'}
-
- with self.assertRaises(ValueError) as cm:
- lc.register_block((x, y), 'foo')
- self.assertIn('was already registered', str(cm.exception))
-
- def testRegisterTupleParamSomeRegistered(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')}
-
- with self.assertRaises(ValueError) as cm:
- lc.register_block((x, y), MockFisherBlock('foo'))
- self.assertIn('was already registered', str(cm.exception))
-
- def testRegisterTupleVarSomeRegisteredInOtherTuples(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
- w = variable_scope.get_variable('w', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {(x, z): '1', (z, w): '2'}
-
- with self.assertRaises(ValueError) as cm:
- lc.register_block((x, y), 'foo')
- self.assertIn('was already registered', str(cm.exception))
-
- def testRegisterCategoricalPredictiveDistribution(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- logits = linalg_ops.eye(2)
-
- lc = layer_collection.LayerCollection()
- lc.register_categorical_predictive_distribution(logits, seed=200)
- single_loss = sess.run(lc.total_sampled_loss())
-
- lc2 = layer_collection.LayerCollection()
- lc2.register_categorical_predictive_distribution(logits, seed=200)
- lc2.register_categorical_predictive_distribution(logits, seed=200)
- double_loss = sess.run(lc2.total_sampled_loss())
- self.assertAlmostEqual(2 * single_loss, double_loss)
-
- def testLossFunctionByName(self):
- """Ensure loss functions can be identified by name."""
- with ops.Graph().as_default():
- logits = linalg_ops.eye(2)
- lc = layer_collection.LayerCollection()
-
- # Create a new loss function by name.
- lc.register_categorical_predictive_distribution(logits, name='loss1')
- self.assertEqual(1, len(lc.towers_by_loss))
-
- # Add logits to same loss function.
- lc.register_categorical_predictive_distribution(
- logits, name='loss1', reuse=True)
- self.assertEqual(1, len(lc.towers_by_loss))
-
- # Add another new loss function.
- lc.register_categorical_predictive_distribution(logits, name='loss2')
- self.assertEqual(2, len(lc.towers_by_loss))
-
- def testLossFunctionWithoutName(self):
- """Ensure loss functions get unique names if 'name' not specified."""
- with ops.Graph().as_default():
- logits = linalg_ops.eye(2)
- lc = layer_collection.LayerCollection()
-
- # Create a new loss function with default names.
- lc.register_categorical_predictive_distribution(logits)
- lc.register_categorical_predictive_distribution(logits)
- self.assertEqual(2, len(lc.losses))
-
- def testCategoricalPredictiveDistributionMultipleMinibatches(self):
- """Ensure multiple minibatches are registered."""
- with ops.Graph().as_default():
- batch_size = 3
- output_size = 2
- logits = array_ops.zeros([batch_size, output_size])
- targets = array_ops.ones([batch_size], dtype=dtypes.int32)
- lc = layer_collection.LayerCollection()
-
- # Create a new loss function.
- lc.register_categorical_predictive_distribution(
- logits, targets=targets, name='loss1')
-
- # Can add when reuse=True
- lc.register_categorical_predictive_distribution(
- logits, targets=targets, name='loss1', reuse=True)
-
- # Can add when reuse=VARIABLE_SCOPE and reuse=True there.
- with variable_scope.variable_scope(
- variable_scope.get_variable_scope(), reuse=True):
- lc.register_categorical_predictive_distribution(
- logits,
- targets=targets,
- name='loss1',
- reuse=layer_collection.VARIABLE_SCOPE)
-
- # Can't add when reuse=False
- with self.assertRaises(KeyError):
- lc.register_categorical_predictive_distribution(
- logits, targets=targets, name='loss1', reuse=False)
-
- # Can't add when reuse=VARIABLE_SCOPE and reuse=False there.
- with self.assertRaises(KeyError):
- lc.register_categorical_predictive_distribution(
- logits,
- targets=targets,
- name='loss1',
- reuse=layer_collection.VARIABLE_SCOPE)
-
- self.assertEqual(len(lc.towers_by_loss), 1)
- # Three successful registrations.
- self.assertEqual(len(lc.towers_by_loss[0]), 3)
-
- def testRegisterCategoricalPredictiveDistributionBatchSize1(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- logits = random_ops.random_normal((1, 2))
- lc = layer_collection.LayerCollection()
-
- lc.register_categorical_predictive_distribution(logits, seed=200)
-
- def testRegisterCategoricalPredictiveDistributionSpecifiedTargets(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- logits = array_ops.constant([[1., 2.], [3., 4.]], dtype=dtypes.float32)
- lc = layer_collection.LayerCollection()
- targets = array_ops.constant([0, 1], dtype=dtypes.int32)
-
- lc.register_categorical_predictive_distribution(logits, targets=targets)
- single_loss = sess.run(lc.total_loss())
- self.assertAlmostEqual(1.6265233, single_loss)
-
- def testRegisterNormalPredictiveDistribution(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- predictions = array_ops.constant(
- [[1., 2.], [3., 4]], dtype=dtypes.float32)
-
- lc = layer_collection.LayerCollection()
- lc.register_normal_predictive_distribution(predictions, 1., seed=200)
- single_loss = sess.run(lc.total_sampled_loss())
-
- lc2 = layer_collection.LayerCollection()
- lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
- lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
- double_loss = sess.run(lc2.total_sampled_loss())
-
- self.assertAlmostEqual(2 * single_loss, double_loss)
-
- def testRegisterNormalPredictiveDistributionSpecifiedTargets(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- predictions = array_ops.constant(
- [[1., 2.], [3., 4.]], dtype=dtypes.float32)
- lc = layer_collection.LayerCollection()
- targets = array_ops.constant([[3., 1.], [4., 2.]], dtype=dtypes.float32)
-
- lc.register_normal_predictive_distribution(
- predictions, 2.**2, targets=targets)
- single_loss = sess.run(lc.total_loss())
- self.assertAlmostEqual(7.6983433, single_loss)
-
- def ensureLayerReuseWorks(self, register_fn):
- """Ensure the 'reuse' keyword argument function as intended.
-
- Args:
- register_fn: function for registering a layer. Arguments are
- layer_collection, reuse, and approx.
- """
- # Fails on second if reuse=False.
- lc = layer_collection.LayerCollection()
- register_fn(lc)
- with self.assertRaises(ValueError):
- register_fn(lc, reuse=False)
-
- # Succeeds on second if reuse=True.
- lc = layer_collection.LayerCollection()
- register_fn(lc)
- register_fn(lc, reuse=True)
-
- # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse.
- lc = layer_collection.LayerCollection()
- register_fn(lc)
- with self.assertRaises(ValueError):
- register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
-
- # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse.
- lc = layer_collection.LayerCollection()
- register_fn(lc)
- with variable_scope.variable_scope(
- variable_scope.get_variable_scope(), reuse=True):
- register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
-
- # Fails if block type changes.
- lc = layer_collection.LayerCollection()
- register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME)
- with self.assertRaises(ValueError):
- register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True)
-
- # Fails if reuse requested but no FisherBlock exists.
- lc = layer_collection.LayerCollection()
- with self.assertRaises(KeyError):
- register_fn(lc, reuse=True)
-
- def testRegisterFullyConnectedReuse(self):
- """Ensure the 'reuse' works with register_fully_connected."""
- with ops.Graph().as_default():
- inputs = array_ops.ones([2, 10])
- outputs = array_ops.zeros([2, 5])
- params = (
- variable_scope.get_variable('w', [10, 5]), #
- variable_scope.get_variable('b', [5]))
-
- def register_fn(lc, **kwargs):
- lc.register_fully_connected(
- params=params, inputs=inputs, outputs=outputs, **kwargs)
-
- self.ensureLayerReuseWorks(register_fn)
-
- def testRegisterConv2dReuse(self):
- """Ensure the 'reuse' works with register_conv2d."""
- with ops.Graph().as_default():
- inputs = array_ops.ones([2, 5, 5, 10])
- outputs = array_ops.zeros([2, 5, 5, 3])
- params = (
- variable_scope.get_variable('w', [1, 1, 10, 3]), #
- variable_scope.get_variable('b', [3]))
-
- def register_fn(lc, **kwargs):
- lc.register_conv2d(
- params=params,
- strides=[1, 1, 1, 1],
- padding='SAME',
- inputs=inputs,
- outputs=outputs,
- **kwargs)
-
- self.ensureLayerReuseWorks(register_fn)
-
- def testReuseWithInvalidRegistration(self):
- """Invalid registrations shouldn't overwrite existing blocks."""
- with ops.Graph().as_default():
- inputs = array_ops.ones([2, 5, 5, 10])
- outputs = array_ops.zeros([2, 5, 5, 3])
- w = variable_scope.get_variable('w', [1, 1, 10, 3])
- b = variable_scope.get_variable('b', [3])
- lc = layer_collection.LayerCollection()
- lc.register_fully_connected(w, inputs, outputs)
- self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
- with self.assertRaises(KeyError):
- lc.register_fully_connected((w, b), inputs, outputs, reuse=True)
- self.assertNotIn((w, b), lc.fisher_blocks)
- self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
- lc.register_fully_connected(w, inputs, outputs, reuse=True)
- self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2)
-
- def testMakeOrGetFactor(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- lc = layer_collection.LayerCollection()
- key = array_ops.constant(1)
- lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
- lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
- lc.make_or_get_factor(fisher_factors.FullFactor,
- ((array_ops.constant(2),), 16))
-
- self.assertEqual(2, len(lc.get_factors()))
- variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertTrue(
- all([var.name.startswith('LayerCollection') for var in variables]))
-
- def testMakeOrGetFactorCustomScope(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- scope = 'Foo'
- lc = layer_collection.LayerCollection(name=scope)
- key = array_ops.constant(1)
- lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
- lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
- lc.make_or_get_factor(fisher_factors.FullFactor,
- ((array_ops.constant(2),), 16))
-
- self.assertEqual(2, len(lc.get_factors()))
- variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertTrue(all([var.name.startswith(scope) for var in variables]))
-
- def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self):
- x = variable_scope.get_variable('x', shape=())
- y = variable_scope.get_variable('y', shape=())
- z = variable_scope.get_variable('z', shape=())
- lc = layer_collection.LayerCollection()
- lc.define_linked_parameters((x, y))
-
- with self.assertRaises(ValueError):
- lc.define_linked_parameters((x, z))
-
- def testIdentifySubsetPreviouslyRegisteredTensor(self):
- x = variable_scope.get_variable('x', shape=())
- y = variable_scope.get_variable('y', shape=())
- lc = layer_collection.LayerCollection()
- lc.define_linked_parameters((x, y))
-
- with self.assertRaises(ValueError):
- lc.define_linked_parameters(x)
-
- def testSpecifyApproximation(self):
- w_0 = variable_scope.get_variable('w_0', [10, 10])
- w_1 = variable_scope.get_variable('w_1', [10, 10])
-
- b_0 = variable_scope.get_variable('b_0', [10])
- b_1 = variable_scope.get_variable('b_1', [10])
-
- x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
- x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
-
- pre_bias_0 = math_ops.matmul(x_0, w_0)
- pre_bias_1 = math_ops.matmul(x_1, w_1)
-
- # Build the fully connected layers in the graph.
- pre_bias_0 + b_0 # pylint: disable=pointless-statement
- pre_bias_1 + b_1 # pylint: disable=pointless-statement
-
- lc = layer_collection.LayerCollection()
- lc.define_linked_parameters(
- w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME)
- lc.define_linked_parameters(
- w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME)
- lc.define_linked_parameters(
- b_0, approximation=layer_collection.APPROX_FULL_NAME)
- lc.define_linked_parameters(
- b_1, approximation=layer_collection.APPROX_FULL_NAME)
-
- lc.register_fully_connected(w_0, x_0, pre_bias_0)
- lc.register_fully_connected(
- w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME)
- self.assertIsInstance(lc.fisher_blocks[w_0],
- fisher_blocks.FullyConnectedDiagonalFB)
- self.assertIsInstance(lc.fisher_blocks[w_1],
- fisher_blocks.FullyConnectedKFACBasicFB)
-
- lc.register_generic(b_0, batch_size=1)
- lc.register_generic(
- b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME)
- self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB)
- self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB)
-
- def testDefaultLayerCollection(self):
- with ops.Graph().as_default():
- # Can't get default if there isn't one set.
- with self.assertRaises(ValueError):
- layer_collection.get_default_layer_collection()
-
- # Can't set default twice.
- lc = layer_collection.LayerCollection()
- layer_collection.set_default_layer_collection(lc)
- with self.assertRaises(ValueError):
- layer_collection.set_default_layer_collection(lc)
-
- # Same as one set.
- self.assertTrue(lc is layer_collection.get_default_layer_collection())
-
- # Can set to None.
- layer_collection.set_default_layer_collection(None)
- with self.assertRaises(ValueError):
- layer_collection.get_default_layer_collection()
-
- # as_default() is the same as setting/clearing.
- with lc.as_default():
- self.assertTrue(lc is layer_collection.get_default_layer_collection())
- with self.assertRaises(ValueError):
- layer_collection.get_default_layer_collection()
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
deleted file mode 100644
index f424e02360..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.loss_functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.kfac.python.ops import loss_functions
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class InsertSliceInZerosTest(test.TestCase):
-
- def testBadShape(self):
- bad_shaped_ones = array_ops.ones(shape=[1, 3]) # n.b. shape[1] != 1
- with self.assertRaises(ValueError):
- loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17)
-
- def test3d(self):
- input_tensor = constant_op.constant([[[1, 2]], [[3, 4]]])
- expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]]
- op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0)
- with self.cached_session() as sess:
- actual_output_array = sess.run(op)
- self.assertAllEqual(expected_output_array, actual_output_array)
-
-
-class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
-
- def testSample(self):
- """Ensure samples can be drawn."""
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits))
- sample = loss.sample(42)
- sample = sess.run(sample)
- self.assertEqual(sample.shape, (2,))
-
- def testEvaluateOnTargets(self):
- """Ensure log probability can be evaluated correctly."""
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- targets = np.asarray([2, 1]).astype(np.int32)
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits), targets=array_ops.constant(targets))
- neg_log_prob = loss.evaluate()
- neg_log_prob = sess.run(neg_log_prob)
-
- # Calculate explicit log probability of targets.
- probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
- log_probs = np.log([
- probs[0, targets[0]], #
- probs[1, targets[1]]
- ])
- expected_log_prob = np.sum(log_probs)
-
- self.assertAllClose(neg_log_prob, -expected_log_prob)
-
- def testEvaluateOnSample(self):
- """Ensure log probability of a sample can be drawn."""
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits))
- neg_log_prob = loss.evaluate_on_sample(42)
-
- # Simply ensure this doesn't crash. As the output is random, it's
- # difficult to say if the output is correct or not...
- neg_log_prob = sess.run(neg_log_prob)
-
- def testMultiplyFisherSingleVector(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.array([1., 2., 3.])
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
-
- # the LossFunction.multiply_fisher docstring only says it supports the
- # case where the vector is the same shape as the input natural parameters
- # (i.e. the logits here), but here we also test leading dimensions
- vector = np.array([1., 2., 3.])
- vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)]
-
- probs = np.exp(logits - np.logaddexp.reduce(logits))
- fisher = np.diag(probs) - np.outer(probs, probs)
-
- for vector in vectors:
- result = loss.multiply_fisher(vector)
- expected_result = np.dot(vector, fisher)
- self.assertAllClose(expected_result, sess.run(result))
-
- def testMultiplyFisherBatch(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.array([[1., 2., 3.], [4., 6., 8.]])
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
-
- vector = np.array([[1., 2., 3.], [5., 3., 1.]])
-
- na = np.newaxis
- probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1,
- keepdims=True))
- fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :]
-
- result = loss.multiply_fisher(vector)
- expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :]
- self.assertEqual(sess.run(result).shape, logits.shape)
- self.assertAllClose(expected_result, sess.run(result))
-
-
-class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase):
-
- def testSample(self):
- """Ensure samples can be drawn."""
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits))
- sample = loss.sample(42)
- sample = sess.run(sample)
- self.assertEqual(sample.shape, (2, 3))
-
- def testEvaluateOnTargets(self):
- """Ensure log probability can be evaluated correctly."""
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- targets = np.asarray([2, 1]).astype(np.int32)
- loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits), targets=array_ops.one_hot(targets, 3))
- neg_log_prob = loss.evaluate()
- neg_log_prob = sess.run(neg_log_prob)
-
- # Calculate explicit log probability of targets.
- probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
- log_probs = np.log([
- probs[0, targets[0]], #
- probs[1, targets[1]]
- ])
- expected_log_prob = np.sum(log_probs)
-
- self.assertAllClose(neg_log_prob, -expected_log_prob)
-
- def testEvaluateOnSample(self):
- """Ensure log probability of a sample can be drawn."""
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits))
- neg_log_prob = loss.evaluate_on_sample(42)
-
- # Simply ensure this doesn't crash. As the output is random, it's
- # difficult to say if the output is correct or not...
- neg_log_prob = sess.run(neg_log_prob)
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py b/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py
deleted file mode 100644
index 4fae4374e1..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.op_queue."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.kfac.python.ops import op_queue
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class OpQueueTest(test.TestCase):
-
- def testNextOp(self):
- """Ensures all ops get selected eventually."""
- with tf_ops.Graph().as_default():
- ops = [
- math_ops.add(1, 2),
- math_ops.subtract(1, 2),
- math_ops.reduce_mean([1, 2]),
- ]
- queue = op_queue.OpQueue(ops, seed=0)
-
- with self.cached_session() as sess:
- # Ensure every inv update op gets selected.
- selected_ops = set([queue.next_op(sess) for _ in ops])
- self.assertEqual(set(ops), set(selected_ops))
-
- # Ensure additional calls don't create any new ops.
- selected_ops.add(queue.next_op(sess))
- self.assertEqual(set(ops), set(selected_ops))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py
deleted file mode 100644
index 0b0de12ce6..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py
+++ /dev/null
@@ -1,219 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.optimizer."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
-from tensorflow.contrib.kfac.python.ops import layer_collection as lc
-from tensorflow.contrib.kfac.python.ops import optimizer
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.platform import test
-
-
-# We need to set these constants since the numerical values used in the tests
-# were chosen when these used to be the defaults.
-ff.set_global_constants(init_covariances_at_zero=False,
- zero_debias=False,
- init_inverses_at_zero=False)
-
-
-def dummy_layer_collection():
- lcoll = lc.LayerCollection()
- dummy = array_ops.constant([1., 2.])
- lcoll.register_categorical_predictive_distribution(logits=dummy)
- return lcoll
-
-
-class OptimizerTest(test.TestCase):
-
- def testOptimizerInitInvalidMomentumRegistration(self):
- with self.assertRaises(ValueError):
- optimizer.KfacOptimizer(
- 0.1, 0.2, 0.3, lc.LayerCollection(), momentum_type='foo')
-
- def testOptimizerInit(self):
- with ops.Graph().as_default():
- layer_collection = lc.LayerCollection()
-
- inputs = array_ops.ones((2, 1)) * 2
- weights_val = np.ones((1, 1), dtype=np.float32) * 3.
- weights = variable_scope.get_variable(
- 'w', initializer=array_ops.constant(weights_val))
- bias = variable_scope.get_variable(
- 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
- output = math_ops.matmul(inputs, weights) + bias
-
- layer_collection.register_fully_connected((weights, bias), inputs, output)
-
- logits = math_ops.tanh(output)
- targets = array_ops.constant([[0.], [1.]])
- output = math_ops.reduce_mean(
- nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))
-
- layer_collection.register_categorical_predictive_distribution(logits)
-
- optimizer.KfacOptimizer(
- 0.1,
- 0.2,
- 0.3,
- layer_collection,
- momentum=0.5,
- momentum_type='regular')
-
- def testSquaredFisherNorm(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
- (array_ops.constant([[2., 3.], [4., 5.]]), None)]
- pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
- (array_ops.constant([[7., 8.], [9., 10.]]), None)]
- opt = optimizer.KfacOptimizer(0.1, 0.2, 0.3, dummy_layer_collection())
- sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
- self.assertAlmostEqual(174., sess.run(sq_norm), places=5)
-
- def testUpdateClipCoeff(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
- (array_ops.constant([[2., 3.], [4., 5.]]), None)]
- pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
- (array_ops.constant([[7., 8.], [9., 10.]]), None)]
- lrate = 0.1
-
- # Note: without rescaling, the squared Fisher norm of the update
- # is 1.74
-
- # If the update already satisfies the norm constraint, there should
- # be no rescaling.
- opt = optimizer.KfacOptimizer(
- lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=10.)
- coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
- self.assertAlmostEqual(1., sess.run(coeff), places=5)
-
- # If the update violates the constraint, it should be rescaled to
- # be on the constraint boundary.
- opt = optimizer.KfacOptimizer(
- lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=0.5)
- coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
- sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
- sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad
- self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5)
-
- def testComputeUpdateStepsRegular(self):
- # TODO(olganw): implement this.
- pass
-
- def testComputeUpdateStepsAdam(self):
- # TODO(olganw): implement this.
- pass
-
- def testUpdateVelocities(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- layers = lc.LayerCollection()
- layers.register_categorical_predictive_distribution(
- array_ops.constant([1.0]))
- opt = optimizer.KfacOptimizer(
- 0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular')
- x = variable_scope.get_variable('x', initializer=array_ops.ones((2, 2)))
- y = variable_scope.get_variable(
- 'y', initializer=array_ops.ones((2, 2)) * 2)
- vec1 = array_ops.ones((2, 2)) * 3
- vec2 = array_ops.ones((2, 2)) * 4
-
- model_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5)
- opt_vars = [
- v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- if v not in model_vars
- ]
-
- sess.run(tf_variables.global_variables_initializer())
- old_opt_vars = sess.run(opt_vars)
-
- # Optimizer vars start out at 0.
- for opt_var in old_opt_vars:
- self.assertAllEqual(sess.run(array_ops.zeros_like(opt_var)), opt_var)
-
- sess.run(update_op)
- new_opt_vars = sess.run(opt_vars)
- # After one update, the velocities are equal to the vectors.
- for vec, opt_var in zip([vec1, vec2], new_opt_vars):
- self.assertAllEqual(sess.run(vec), opt_var)
-
- sess.run(update_op)
- final_opt_vars = sess.run(opt_vars)
- for first, second in zip(new_opt_vars, final_opt_vars):
- self.assertFalse(np.equal(first, second).all())
-
- def testApplyGradients(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- layer_collection = lc.LayerCollection()
-
- inputs = array_ops.ones((2, 1)) * 2
- weights_val = np.ones((1, 1), dtype=np.float32) * 3.
- weights = variable_scope.get_variable(
- 'w', initializer=array_ops.constant(weights_val))
- bias = variable_scope.get_variable(
- 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
- output = math_ops.matmul(inputs, weights) + bias
-
- layer_collection.register_fully_connected((weights, bias), inputs, output)
-
- logits = math_ops.tanh(output)
- targets = array_ops.constant([[0.], [1.]])
- output = math_ops.reduce_mean(
- nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))
-
- layer_collection.register_categorical_predictive_distribution(logits)
-
- opt = optimizer.KfacOptimizer(
- 0.1,
- 0.2,
- 0.3,
- layer_collection,
- momentum=0.5,
- momentum_type='regular')
- (cov_update_thunks,
- inv_update_thunks) = opt.make_vars_and_create_op_thunks()
- cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
- inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
-
- grads_and_vars = opt.compute_gradients(output, [weights, bias])
- all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars]
-
- op = opt.apply_gradients(grads_and_vars)
-
- sess.run(tf_variables.global_variables_initializer())
- old_vars = sess.run(all_vars)
- sess.run(cov_update_ops)
- sess.run(inv_update_ops)
- sess.run(op)
- new_vars = sess.run(all_vars)
-
- for old_var, new_var in zip(old_vars, new_vars):
- self.assertNotEqual(old_var, new_var)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
deleted file mode 100644
index 7df79a3c7f..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
+++ /dev/null
@@ -1,410 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.utils."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import numpy.random as npr
-
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.contrib.tpu.python.tpu import tpu_function
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-
-
-class SequenceDictTest(test.TestCase):
-
- def testSequenceDictInit(self):
- seq_dict = utils.SequenceDict()
- self.assertFalse(seq_dict._dict)
-
- def testSequenceDictInitWithIterable(self):
- reg_dict = {'a': 'foo', 'b': 'bar'}
- itr = zip(reg_dict.keys(), reg_dict.values())
- seq_dict = utils.SequenceDict(itr)
- self.assertEqual(reg_dict, seq_dict._dict)
-
- def testGetItemSingleKey(self):
- seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
- self.assertEqual('foo', seq_dict['a'])
-
- def testGetItemMultipleKeys(self):
- seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
- self.assertEqual(['foo', 'bar'], seq_dict[('a', 'b')])
-
- def testSetItemSingleKey(self):
- seq_dict = utils.SequenceDict()
- seq_dict['a'] = 'foo'
- self.assertEqual([('a', 'foo')], seq_dict.items())
-
- def testSetItemMultipleKeys(self):
- seq_dict = utils.SequenceDict()
- keys = ('a', 'b', 'c')
- values = ('foo', 'bar', 'baz')
- seq_dict[keys] = values
- self.assertItemsEqual(list(zip(keys, values)), seq_dict.items())
-
-
-class SubGraphTest(test.TestCase):
-
- def testBasicGraph(self):
- a = array_ops.constant([[1., 2.], [3., 4.]])
- b = array_ops.constant([[5., 6.], [7., 8.]])
- c = a + b
- d = a * b
- sub_graph = utils.SubGraph((c,))
- self.assertTrue(sub_graph.is_member(a))
- self.assertTrue(sub_graph.is_member(b))
- self.assertTrue(sub_graph.is_member(c))
- self.assertFalse(sub_graph.is_member(d))
-
- def testRepeatedAdds(self):
- a = array_ops.constant([[1., 2.], [3., 4.]])
- b = array_ops.constant([[5., 6.], [7., 8.]])
- c = a + b + a # note that a appears twice in this graph
- sub_graph = utils.SubGraph((c,))
- self.assertTrue(sub_graph.is_member(a))
- self.assertTrue(sub_graph.is_member(b))
- self.assertTrue(sub_graph.is_member(c))
-
- def testFilterList(self):
- a = array_ops.constant([[1., 2.], [3., 4.]])
- b = array_ops.constant([[5., 6.], [7., 8.]])
- c = a + b
- d = a * b
- sub_graph = utils.SubGraph((c,))
- input_list = [b, d]
- filtered_list = sub_graph.filter_list(input_list)
- self.assertEqual(filtered_list, [b])
-
- def testVariableUses(self):
- with ops.Graph().as_default():
- var = variable_scope.get_variable('var', shape=[10, 10])
- resource_var = variable_scope.get_variable(
- 'resource_var', shape=[10, 10], use_resource=True)
- x = array_ops.zeros([3, 10])
- z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var)
- z1 = math_ops.matmul(x, resource_var)
- sub_graph = utils.SubGraph((z0, z1))
- self.assertEqual(2, sub_graph.variable_uses(var))
- self.assertEqual(1, sub_graph.variable_uses(resource_var))
-
-
-class UtilsTest(test.TestCase):
-
- def _fully_connected_layer_params(self):
- weights_part = array_ops.constant([[1., 2.], [4., 3.]])
- bias_part = array_ops.constant([1., 2.])
- return (weights_part, bias_part)
-
- def _conv_layer_params(self):
- weights_shape = 2, 2, 3, 4
- biases_shape = weights_shape[-1:]
- weights = array_ops.constant(npr.RandomState(0).randn(*weights_shape))
- biases = array_ops.constant(npr.RandomState(1).randn(*biases_shape))
- return (weights, biases)
-
- def testFullyConnectedLayerParamsTupleToMat2d(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- layer_params = self._fully_connected_layer_params()
- output = utils.layer_params_to_mat2d(layer_params)
- self.assertListEqual([3, 2], output.get_shape().as_list())
- self.assertAllClose(
- sess.run(output), np.array([[1., 2.], [4., 3.], [1., 2.]]))
-
- def testFullyConnectedLayerParamsTensorToMat2d(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- layer_params = self._fully_connected_layer_params()
- output = utils.layer_params_to_mat2d(layer_params[0])
- self.assertListEqual([2, 2], output.get_shape().as_list())
- self.assertAllClose(sess.run(output), np.array([[1., 2.], [4., 3.]]))
-
- def testConvLayerParamsTupleToMat2d(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- layer_params = self._conv_layer_params()
- output = utils.layer_params_to_mat2d(layer_params)
- self.assertListEqual([2 * 2 * 3 + 1, 4], output.get_shape().as_list())
-
- def testKron(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- mat1 = np.array([[1., 2.], [3., 4.]])
- mat2 = np.array([[5., 6.], [7., 8.]])
- mat1_tf = array_ops.constant(mat1)
- mat2_tf = array_ops.constant(mat2)
- ans_tf = sess.run(utils.kronecker_product(mat1_tf, mat2_tf))
- ans_np = np.kron(mat1, mat2)
- self.assertAllClose(ans_tf, ans_np)
-
- def testMat2dToFullyConnectedLayerParamsTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- vector_template = self._fully_connected_layer_params()
- mat2d = array_ops.constant([[5., 4.], [3., 2.], [1., 0.]])
-
- output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))
-
- self.assertIsInstance(output, tuple)
- self.assertEqual(len(output), 2)
- a, b = output
- self.assertAllClose(a, np.array([[5., 4.], [3., 2.]]))
- self.assertAllClose(b, np.array([1., 0.]))
-
- def testMat2dToFullyConnectedLayerParamsTensor(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- vector_template = self._fully_connected_layer_params()[0]
- mat2d = array_ops.constant([[5., 4.], [3., 2.]])
-
- output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))
-
- self.assertAllClose(output, np.array([[5., 4.], [3., 2.]]))
-
- def testTensorsToColumn(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
-
- vector = array_ops.constant(np.array([[0., 1.], [2., 3.]]))
- output = utils.tensors_to_column(vector)
- self.assertListEqual([4, 1], output.get_shape().as_list())
- self.assertAllClose(sess.run(output), np.array([0., 1., 2., 3.])[:, None])
-
- vector = self._fully_connected_layer_params()
- output = utils.tensors_to_column(vector)
- self.assertListEqual([6, 1], output.get_shape().as_list())
- self.assertAllClose(
- sess.run(output), np.array([1., 2., 4., 3., 1., 2.])[:, None])
-
- vector = list(vector)
- vector.append(array_ops.constant([[6.], [7.], [8.], [9.]]))
-
- output = utils.tensors_to_column(vector)
- self.assertListEqual([10, 1], output.get_shape().as_list())
- self.assertAllClose(
- sess.run(output),
- np.array([1., 2., 4., 3., 1., 2., 6., 7., 8., 9.])[:, None])
-
- def testColumnToTensors(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
-
- vector_template = array_ops.constant(np.array([[0., 1.], [2., 3.]]))
- colvec = array_ops.constant(np.arange(4.)[:, None])
- output = sess.run(utils.column_to_tensors(vector_template, colvec))
- self.assertAllClose(output, np.array([[0., 1.], [2., 3.]]))
-
- vector_template = self._fully_connected_layer_params()
- colvec = array_ops.constant(np.arange(6.)[:, None])
- output = sess.run(utils.column_to_tensors(vector_template, colvec))
-
- self.assertIsInstance(output, tuple)
- self.assertEqual(len(output), 2)
- a, b = output
- self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))
- self.assertAllClose(b, np.array([4., 5.]))
-
- vector_template = list(vector_template)
- vector_template.append(array_ops.constant([[6.], [7.], [8.], [9.]]))
- colvec = array_ops.constant(np.arange(10.)[:, None])
- output = sess.run(utils.column_to_tensors(vector_template, colvec))
- self.assertIsInstance(output, tuple)
- self.assertEqual(len(output), 3)
- a, b, c = output
- self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))
- self.assertAllClose(b, np.array([4., 5.]))
- self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]]))
-
- def testPosDefInvCholesky(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- npr.seed(0)
- square = lambda x: np.dot(x, x.T)
-
- size = 3
- x = square(npr.randn(size, size))
- damp = 0.1
- identity = linalg_ops.eye(size, dtype=dtypes.float64)
-
- tf_inv = utils.posdef_inv_cholesky(array_ops.constant(x), identity, damp)
- np_inv = np.linalg.inv(x + damp * np.eye(size))
- self.assertAllClose(sess.run(tf_inv), np_inv)
-
- def testPosDefInvMatrixInverse(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- npr.seed(0)
- square = lambda x: np.dot(x, x.T)
-
- size = 3
- x = square(npr.randn(size, size))
- damp = 0.1
- identity = linalg_ops.eye(size, dtype=dtypes.float64)
-
- tf_inv = utils.posdef_inv_matrix_inverse(
- array_ops.constant(x), identity, damp)
- np_inv = np.linalg.inv(x + damp * np.eye(size))
- self.assertAllClose(sess.run(tf_inv), np_inv)
-
- def testCrossReplicaMean(self):
- """Ensures that cross_replica_mean() executes only when num_shards > 1."""
- with ops.Graph().as_default():
- with tpu_function.tpu_shard_context(4):
- tensor = array_ops.zeros([], dtype=dtypes.float32)
- mean = utils.cross_replica_mean(tensor)
- self.assertNotEqual(mean, tensor)
-
- with ops.Graph().as_default():
- with tpu_function.tpu_shard_context(1):
- tensor = array_ops.zeros([], dtype=dtypes.float32)
- mean = utils.cross_replica_mean(tensor)
- self.assertEqual(mean, tensor)
-
- with ops.Graph().as_default():
- with self.assertRaises(ValueError): # Outside of TPU context.
- tensor = array_ops.zeros([], dtype=dtypes.float32)
- mean = utils.cross_replica_mean(tensor)
-
- def testBatchExecute(self):
- """Ensure batch_execute runs in a round-robin fashion."""
-
- def increment_var(var):
- return lambda: var.assign_add(1)
-
- with ops.Graph().as_default(), self.cached_session() as sess:
- i = variable_scope.get_variable('i', initializer=0)
- accumulators = [
- variable_scope.get_variable('var%d' % j, initializer=0)
- for j in range(3)
- ]
- thunks = [increment_var(var) for var in accumulators]
- increment_accumulators = utils.batch_execute(i, thunks, 2)
- increment_i = i.assign_add(1)
-
- sess.run(variables.global_variables_initializer())
-
- # Ensure one op per thunk.
- self.assertEqual(3, len(increment_accumulators))
-
- # Ensure round-robin execution.
- values = []
- for _ in range(5):
- sess.run(increment_accumulators)
- sess.run(increment_i)
- values.append(sess.run(accumulators))
- self.assertAllClose(
- [
- [1, 1, 0], #
- [2, 1, 1], #
- [2, 2, 2], #
- [3, 3, 2], #
- [4, 3, 3]
- ],
- values)
-
- def testExtractConvolutionPatches(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- batch_size = 10
- image_spatial_shape = [9, 10, 11]
- in_channels = out_channels = 32
- kernel_spatial_shape = [5, 3, 3]
- spatial_strides = [1, 2, 1]
- spatial_dilation = [1, 1, 1]
- padding = 'SAME'
-
- images = random_ops.random_uniform(
- [batch_size] + image_spatial_shape + [in_channels], seed=0)
- kernel_shape = kernel_spatial_shape + [in_channels, out_channels]
- kernel = random_ops.random_uniform(kernel_shape, seed=1)
-
- # Ensure shape matches expectation.
- patches = utils.extract_convolution_patches(
- images,
- kernel_shape,
- padding,
- strides=spatial_strides,
- dilation_rate=spatial_dilation)
- result_spatial_shape = (
- patches.shape.as_list()[1:1 + len(image_spatial_shape)])
- self.assertEqual(patches.shape.as_list(),
- [batch_size] + result_spatial_shape +
- kernel_spatial_shape + [in_channels])
-
- # Ensure extract...patches() + matmul() and convolution() implementation
- # give the same answer.
- outputs = nn_ops.convolution(
- images,
- kernel,
- padding,
- strides=spatial_strides,
- dilation_rate=spatial_dilation)
-
- patches_flat = array_ops.reshape(
- patches, [-1, np.prod(kernel_spatial_shape) * in_channels])
- kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
- outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
-
- outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
- self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
-
- def testExtractPointwiseConv2dPatches(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- batch_size = 10
- image_height = image_width = 8
- in_channels = out_channels = 3
- kernel_height = kernel_width = 1
- strides = [1, 1, 1, 1]
- padding = 'VALID'
-
- images = random_ops.random_uniform(
- [batch_size, image_height, image_width, in_channels], seed=0)
- kernel_shape = [kernel_height, kernel_width, in_channels, out_channels]
- kernel = random_ops.random_uniform(kernel_shape, seed=1)
-
- # Ensure shape matches expectation.
- patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape)
- self.assertEqual(patches.shape.as_list(), [
- batch_size, image_height, image_width, kernel_height, kernel_width,
- in_channels
- ])
-
- # Ensure extract...patches() + matmul() and conv2d() implementation
- # give the same answer.
- outputs = nn_ops.conv2d(images, kernel, strides, padding)
-
- patches_flat = array_ops.reshape(
- patches, [-1, kernel_height * kernel_width * in_channels])
- kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
- outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
-
- outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
- self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
deleted file mode 100644
index 3c01eb65e7..0000000000
--- a/tensorflow/contrib/kfac/python/ops/BUILD
+++ /dev/null
@@ -1,263 +0,0 @@
-package(default_visibility = [
- "//tensorflow/contrib/kfac:__pkg__",
- "//tensorflow/contrib/kfac/python/kernel_tests:__pkg__",
-])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-py_library(
- name = "fisher_blocks",
- srcs = ["fisher_blocks.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":fisher_factors",
- ":utils",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:math_ops",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "fisher_blocks_lib",
- srcs = ["fisher_blocks_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":fisher_blocks",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "fisher_factors",
- srcs = ["fisher_factors.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":linear_operator",
- ":utils",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:special_math_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "fisher_factors_lib",
- srcs = ["fisher_factors_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":fisher_factors",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "linear_operator",
- srcs = ["linear_operator.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":utils",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python/ops/linalg",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "loss_functions",
- srcs = ["loss_functions.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/distributions:distributions_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/ops/distributions",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "loss_functions_lib",
- srcs = ["loss_functions_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":loss_functions",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "curvature_matrix_vector_products",
- srcs = ["curvature_matrix_vector_products.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":utils",
- "//tensorflow/python:gradients",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "curvature_matrix_vector_products_lib",
- srcs = ["curvature_matrix_vector_products_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":curvature_matrix_vector_products",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "layer_collection",
- srcs = ["layer_collection.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":fisher_blocks",
- ":loss_functions",
- ":utils",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:util",
- "//tensorflow/python:variable_scope",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "layer_collection_lib",
- srcs = ["layer_collection_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":layer_collection",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "kfac_optimizer",
- srcs = [
- "optimizer.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":curvature_matrix_vector_products",
- ":fisher_estimator",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- ],
-)
-
-py_library(
- name = "kfac_optimizer_lib",
- srcs = [
- "optimizer_lib.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":kfac_optimizer",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "fisher_estimator",
- srcs = [
- "estimator.py",
- "placement.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":utils",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:util",
- "//third_party/py/numpy",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "fisher_estimator_lib",
- srcs = [
- "estimator_lib.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":fisher_estimator",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "utils",
- srcs = ["utils.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/tpu",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_library(
- name = "utils_lib",
- srcs = ["utils_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":utils",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "op_queue",
- srcs = ["op_queue.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:dataset_ops",
- "//tensorflow/python:framework_ops",
- ],
-)
-
-py_library(
- name = "op_queue_lib",
- srcs = ["op_queue_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":op_queue",
- "//tensorflow/python:util",
- ],
-)
diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py
deleted file mode 100644
index 21b5cde9b9..0000000000
--- a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Curvature matrix-vector multiplication."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import math_ops
-from tensorflow.python.util import nest
-
-
-class CurvatureMatrixVectorProductComputer(object):
- """Class for computing matrix-vector products for Fishers, GGNs and Hessians.
-
- In other words we compute M*v where M is the matrix, v is the vector, and
- * refers to standard matrix/vector multiplication (not element-wise
- multiplication).
-
- The matrices are defined in terms of some differential quantity of the total
- loss function with respect to a provided list of tensors ("wrt_tensors").
- For example, the Fisher associated with a log-prob loss w.r.t. the
- parameters.
-
- The 'vecs' argument to each method are lists of tensors that must be the
- size as the corresponding ones from "wrt_tensors". They represent
- the vector being multiplied.
-
- "factors" of the matrix M are defined as matrices B such that B*B^T = M.
- Methods that multiply by the factor B take a 'loss_inner_vecs' argument
- instead of 'vecs', which must be a list of tensors with shapes given by the
- corresponding XXX_inner_shapes property.
-
- Note that matrix-vector products are not normalized by the batch size, nor
- are any damping terms added to the results. These things can be easily
- applied externally, if desired.
-
- See for example: www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf
- and https://arxiv.org/abs/1412.1193 for more information about the
- generalized Gauss-Newton, Fisher, etc., and how to compute matrix-vector
- products.
- """
-
- def __init__(self, losses, wrt_tensors):
- """Create a CurvatureMatrixVectorProductComputer object.
-
- Args:
- losses: A list of LossFunction instances whose sum defines the total loss.
- wrt_tensors: A list of Tensors to compute the differential quantities
- (defining the matrices) with respect to. See class description for more
- info.
- """
- self._losses = losses
- self._inputs_to_losses = list(loss.inputs for loss in losses)
- self._inputs_to_losses_flat = nest.flatten(self._inputs_to_losses)
- self._wrt_tensors = wrt_tensors
-
- @property
- def _total_loss(self):
- return math_ops.add_n(tuple(loss.evaluate() for loss in self._losses))
-
- # Jacobian multiplication functions:
- def _multiply_jacobian(self, vecs):
- """Multiply vecs by the Jacobian of losses."""
- # We stop gradients at wrt_tensors to produce partial derivatives (which is
- # what we want for Jacobians).
- jacobian_vecs_flat = utils.fwd_gradients(
- self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs,
- stop_gradients=self._wrt_tensors)
- return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat)
-
- def _multiply_jacobian_transpose(self, loss_vecs):
- """Multiply vecs by the transpose Jacobian of losses."""
- loss_vecs_flat = nest.flatten(loss_vecs)
- # We stop gradients at wrt_tensors to produce partial derivatives (which is
- # what we want for Jacobians).
- return gradients_impl.gradients(
- self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat,
- stop_gradients=self._wrt_tensors)
-
- # Losses Fisher/Hessian multiplication functions:
- def _multiply_loss_fisher(self, loss_vecs):
- """Multiply loss_vecs by Fisher of total loss."""
- return tuple(
- loss.multiply_fisher(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_vecs))
-
- def _multiply_loss_fisher_factor(self, loss_inner_vecs):
- """Multiply loss_inner_vecs by factor of Fisher of total loss."""
- return tuple(
- loss.multiply_fisher_factor(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_inner_vecs))
-
- def _multiply_loss_fisher_factor_transpose(self, loss_vecs):
- """Multiply loss_vecs by transpose factor of Fisher of total loss."""
- return tuple(
- loss.multiply_fisher_factor_transpose(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_vecs))
-
- def _multiply_loss_hessian(self, loss_vecs):
- """Multiply loss_vecs by Hessian of total loss."""
- return tuple(
- loss.multiply_hessian(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_vecs))
-
- def _multiply_loss_hessian_factor(self, loss_inner_vecs):
- """Multiply loss_inner_vecs by factor of Hessian of total loss."""
- return tuple(
- loss.multiply_hessian_factor(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_inner_vecs))
-
- def _multiply_loss_hessian_factor_transpose(self, loss_vecs):
- """Multiply loss_vecs by transpose factor of Hessian of total loss."""
- return tuple(
- loss.multiply_hessian_factor_transpose(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_vecs))
-
- # Matrix-vector product functions:
- def multiply_fisher(self, vecs):
- """Multiply vecs by Fisher of total loss."""
- jacobian_vecs = self._multiply_jacobian(vecs)
- loss_fisher_jacobian_vecs = self._multiply_loss_fisher(jacobian_vecs)
- return self._multiply_jacobian_transpose(loss_fisher_jacobian_vecs)
-
- def multiply_fisher_factor_transpose(self, vecs):
- """Multiply vecs by transpose of factor of Fisher of total loss."""
- jacobian_vecs = self._multiply_jacobian(vecs)
- return self._multiply_loss_fisher_factor_transpose(jacobian_vecs)
-
- def multiply_fisher_factor(self, loss_inner_vecs):
- """Multiply loss_inner_vecs by factor of Fisher of total loss."""
- fisher_factor_transpose_vecs = self._multiply_loss_fisher_factor_transpose(
- loss_inner_vecs)
- return self._multiply_jacobian_transpose(fisher_factor_transpose_vecs)
-
- def multiply_hessian(self, vecs):
- """Multiply vecs by Hessian of total loss."""
- return gradients_impl.gradients(
- gradients_impl.gradients(self._total_loss, self._wrt_tensors),
- self._wrt_tensors,
- grad_ys=vecs)
-
- def multiply_generalized_gauss_newton(self, vecs):
- """Multiply vecs by generalized Gauss-Newton of total loss."""
- jacobian_vecs = self._multiply_jacobian(vecs)
- loss_hessian_jacobian_vecs = self._multiply_loss_hessian(jacobian_vecs)
- return self._multiply_jacobian_transpose(loss_hessian_jacobian_vecs)
-
- def multiply_generalized_gauss_newton_factor_transpose(self, vecs):
- """Multiply vecs by transpose of factor of GGN of total loss."""
- jacobian_vecs = self._multiply_jacobian(vecs)
- return self._multiply_loss_hessian_factor_transpose(jacobian_vecs)
-
- def multiply_generalized_gauss_newton_factor(self, loss_inner_vecs):
- """Multiply loss_inner_vecs by factor of GGN of total loss."""
- hessian_factor_transpose_vecs = (
- self._multiply_loss_hessian_factor_transpose(loss_inner_vecs))
- return self._multiply_jacobian_transpose(hessian_factor_transpose_vecs)
-
- # Shape properties for multiply_XXX_factor methods:
- @property
- def fisher_factor_inner_shapes(self):
- """Shapes required by multiply_fisher_factor."""
- return tuple(loss.fisher_factor_inner_shape for loss in self._losses)
-
- @property
- def generalized_gauss_newton_factor_inner_shapes(self):
- """Shapes required by multiply_generalized_gauss_newton_factor."""
- return tuple(loss.hessian_factor_inner_shape for loss in self._losses)
diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py
deleted file mode 100644
index 6e8c6404dc..0000000000
--- a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Curvature matrix-vector multiplication."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.curvature_matrix_vector_products import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- 'CurvatureMatrixVectorProductComputer',
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
deleted file mode 100644
index 323234c403..0000000000
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ /dev/null
@@ -1,516 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Defines the high-level Fisher estimator class."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-import numpy as np
-import six
-
-from tensorflow.contrib.kfac.python.ops import placement
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.util import nest
-
-
-# The linter is confused.
-# pylint: disable=abstract-class-instantiated
-def make_fisher_estimator(placement_strategy=None, **kwargs):
- """Creates Fisher estimator instances based on the placement strategy.
-
- For example if the `placement_strategy` is 'round_robin' then
- `FisherEstimatorRoundRobin` instance is returned.
-
- Args:
- placement_strategy: `string`, Strategy to be used for placing covariance
- variables, covariance ops and inverse ops. Check
- `placement.FisherEstimatorRoundRobin` for a concrete example.
- **kwargs: Arguments to be passed into `FisherEstimator` class initializer.
-
- Returns:
- An instance of class which inherits from `FisherEstimator` and the mixin
- which implements specific placement strategy. See,
- `FisherEstimatorRoundRobin` which inherits from `FisherEstimator` and
- `RoundRobinPlacementMixin`.
-
- Raises:
- ValueError: If the `placement_strategy` is not equal to 'round_robin'.
- """
- if placement_strategy in [None, "round_robin"]:
- return FisherEstimatorRoundRobin(**kwargs)
- else:
- raise ValueError("Unimplemented vars and ops "
- "placement strategy : {}".format(placement_strategy))
-# pylint: enable=abstract-class-instantiated
-
-
-@six.add_metaclass(abc.ABCMeta)
-class FisherEstimator(object):
- """Fisher estimator class supporting various approximations of the Fisher.
-
- This is an abstract base class which does not implement a strategy for
- placing covariance variables, covariance update ops and inverse update ops.
- The placement strategies are implemented in `placement.py`. See
- `FisherEstimatorRoundRobin` for example of a concrete subclass with
- a round-robin placement strategy.
- """
-
- def __init__(self,
- variables,
- cov_ema_decay,
- damping,
- layer_collection,
- exps=(-1,),
- estimation_mode="gradients",
- colocate_gradients_with_ops=True,
- name="FisherEstimator",
- compute_cholesky=False,
- compute_cholesky_inverse=False):
- """Create a FisherEstimator object.
-
- Args:
- variables: A `list` of variables or `callable` which returns the variables
- for which to estimate the Fisher. This must match the variables
- registered in layer_collection (if it is not None).
- cov_ema_decay: The decay factor used when calculating the covariance
- estimate moving averages.
- damping: float. The damping factor used to stabilize training due to
- errors in the local approximation with the Fisher information matrix,
- and to regularize the update direction by making it closer to the
- gradient. (Higher damping means the update looks more like a standard
- gradient update - see Tikhonov regularization.)
- layer_collection: The layer collection object, which holds the Fisher
- blocks, Kronecker factors, and losses associated with the
- graph.
- exps: List of floats or ints. These represent the different matrix
- powers of the approximate Fisher that the FisherEstimator will be able
- to multiply vectors by. If the user asks for a matrix power other
- one of these (or 1, which is always supported), there will be a
- failure. (Default: (-1,))
- estimation_mode: The type of estimator to use for the Fishers. Can be
- 'gradients', 'empirical', 'curvature_prop', or 'exact'.
- (Default: 'gradients'). 'gradients' is the basic estimation approach
- from the original K-FAC paper. 'empirical' computes the 'empirical'
- Fisher information matrix (which uses the data's distribution for the
- targets, as opposed to the true Fisher which uses the model's
- distribution) and requires that each registered loss have specified
- targets. 'curvature_propagation' is a method which estimates the
- Fisher using self-products of random 1/-1 vectors times "half-factors"
- of the Fisher, as described here: https://arxiv.org/abs/1206.6464 .
- Finally, 'exact' is the obvious generalization of Curvature
- Propagation to compute the exact Fisher (modulo any additional
- diagonal or Kronecker approximations) by looping over one-hot vectors
- for each coordinate of the output instead of using 1/-1 vectors. It
- is more expensive to compute than the other three options by a factor
- equal to the output dimension, roughly speaking.
- colocate_gradients_with_ops: Whether we should request gradients be
- colocated with their respective ops. (Default: True)
- name: A string. A name given to this estimator, which is added to the
- variable scope when constructing variables and ops.
- (Default: "FisherEstimator")
- compute_cholesky: Bool. Whether or not the FisherEstimator will be
- able to multiply vectors by the Cholesky factor.
- (Default: False)
- compute_cholesky_inverse: Bool. Whether or not the FisherEstimator
- will be able to multiply vectors by the Cholesky factor inverse.
- (Default: False)
- Raises:
- ValueError: If no losses have been registered with layer_collection.
- """
- self._variables = variables
- self._cov_ema_decay = cov_ema_decay
- self._damping = damping
- self._estimation_mode = estimation_mode
- self._layers = layer_collection
- self._gradient_fns = {
- "gradients": self._get_grads_lists_gradients,
- "empirical": self._get_grads_lists_empirical,
- "curvature_prop": self._get_grads_lists_curvature_prop,
- "exact": self._get_grads_lists_exact
- }
- self._colocate_gradients_with_ops = colocate_gradients_with_ops
-
- self._made_vars = False
- self._exps = exps
- self._compute_cholesky = compute_cholesky
- self._compute_cholesky_inverse = compute_cholesky_inverse
-
- self._name = name
-
- @property
- def variables(self):
- if callable(self._variables):
- return self._variables()
- else:
- return self._variables
-
- @property
- def damping(self):
- return self._damping
-
- @property
- def blocks(self):
- """All registered FisherBlocks."""
- return self._layers.get_blocks()
-
- @property
- def factors(self):
- """All registered FisherFactors."""
- return self._layers.get_factors()
-
- @property
- def name(self):
- return self._name
-
- @abc.abstractmethod
- def make_vars_and_create_op_thunks(self, scope=None):
- """Make vars and create op thunks with a specific placement strategy.
-
- For each factor, all of that factor's cov variables and their associated
- update ops will be placed on a particular device. A new device is chosen
- for each factor by cycling through list of devices in the cov_devices
- argument. If cov_devices is None then no explicit device placement occurs.
-
- An analogous strategy is followed for inverse update ops, with the list of
- devices being given by the inv_devices argument.
-
- Inverse variables on the other hand are not placed on any specific device
- (they will just use the current the device placement context, whatever
- that happens to be). The idea is that the inverse variable belong where
- they will be accessed most often, which is the device that actually applies
- the preconditioner to the gradient. The user will be responsible for setting
- the device context for this.
-
- Args:
- scope: A string or None. If None it will be set to the name of this
- estimator (given by the name property). All variables will be created,
- and all thunks will execute, inside of a variable scope of the given
- name. (Default: None)
-
- Returns:
- cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- """
- pass
-
- def _apply_transformation(self, vecs_and_vars, transform):
- """Applies an block-wise transformation to the corresponding vectors.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
- transform: A function of the form f(fb, vec), where vec is the vector
- to transform and fb is its corresponding block in the matrix, that
- returns the transformed vector.
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
-
- vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars)
-
- trans_vecs = utils.SequenceDict()
-
- for params, fb in self._layers.fisher_blocks.items():
- trans_vecs[params] = transform(fb, vecs[params])
-
- return [(trans_vecs[var], var) for _, var in vecs_and_vars]
-
- def multiply_inverse(self, vecs_and_vars):
- """Multiplies the vecs by the corresponding (damped) inverses of the blocks.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
- return self.multiply_matpower(-1, vecs_and_vars)
-
- def multiply(self, vecs_and_vars):
- """Multiplies the vectors by the corresponding (damped) blocks.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
- return self.multiply_matpower(1, vecs_and_vars)
-
- def multiply_matpower(self, exp, vecs_and_vars):
- """Multiplies the vecs by the corresponding matrix powers of the blocks.
-
- Args:
- exp: A float representing the power to raise the blocks by before
- multiplying it by the vector.
- vecs_and_vars: List of (vector, variable) pairs.
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
- assert exp in self._exps
-
- fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
- return self._apply_transformation(vecs_and_vars, fcn)
-
- def multiply_cholesky(self, vecs_and_vars, transpose=False):
- """Multiplies the vecs by the corresponding Cholesky factors.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
- transpose: Bool. If true the Cholesky factors are transposed before
- multiplying the vecs. (Default: False)
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
- assert self._compute_cholesky
-
- fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose)
- return self._apply_transformation(vecs_and_vars, fcn)
-
- def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False):
- """Mults the vecs by the inverses of the corresponding Cholesky factors.
-
- Note: if you are using Cholesky inverse multiplication to sample from
- a matrix-variate Gaussian you will want to multiply by the transpose.
- Let L be the Cholesky factor of F and observe that
-
- L^-T * L^-1 = (L * L^T)^-1 = F^-1 .
-
- Thus we want to multiply by L^-T in order to sample from Gaussian with
- covariance F^-1.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
- transpose: Bool. If true the Cholesky factor inverses are transposed
- before multiplying the vecs. (Default: False)
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
- assert self._compute_cholesky_inverse
-
- fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose)
- return self._apply_transformation(vecs_and_vars, fcn)
-
- def _instantiate_factors(self):
- """Instantiates FisherFactors' variables.
-
- Raises:
- ValueError: If estimation_mode was improperly specified at construction.
- """
- blocks = self.blocks
- tensors_to_compute_grads = [
- block.tensors_to_compute_grads() for block in blocks
- ]
-
- try:
- grads_lists = self._gradient_fns[self._estimation_mode](
- tensors_to_compute_grads)
- except KeyError:
- raise ValueError("Unrecognized value {} for estimation_mode.".format(
- self._estimation_mode))
-
- for grads_list, block in zip(grads_lists, blocks):
- block.instantiate_factors(grads_list, self.damping)
-
- def _check_vars_unmade_and_set_made_flag(self):
- if self._made_vars:
- raise Exception("Already made variables.")
- self._made_vars = True
-
- def made_vars(self):
- return self._made_vars
-
- def _register_matrix_functions(self):
- for block in self.blocks:
- for exp in self._exps:
- block.register_matpower(exp)
- if self._compute_cholesky:
- block.register_cholesky()
- if self._compute_cholesky_inverse:
- block.register_cholesky_inverse()
-
- def _finalize_layer_collection(self):
- self._layers.create_subgraph()
- self._layers.check_registration(self.variables)
- self._instantiate_factors()
- self._register_matrix_functions()
-
- def create_ops_and_vars_thunks(self, scope=None):
- """Create thunks that make the ops and vars on demand.
-
- This function returns 4 lists of thunks: cov_variable_thunks,
- cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
-
- The length of each list is the number of factors and the i-th element of
- each list corresponds to the i-th factor (given by the "factors" property).
-
- Note that the execution of these thunks must happen in a certain
- partial order. The i-th element of cov_variable_thunks must execute
- before the i-th element of cov_update_thunks (and also the i-th element
- of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
- must execute before the i-th element of inv_update_thunks.
-
- TL;DR (oversimplified): Execute the thunks according to the order that
- they are returned.
-
- Args:
- scope: A string or None. If None it will be set to the name of this
- estimator (given by the name property). All thunks will execute inside
- of a variable scope of the given name. (Default: None)
- Returns:
- cov_variable_thunks: A list of thunks that make the cov variables.
- cov_update_thunks: A list of thunks that make the cov update ops.
- inv_variable_thunks: A list of thunks that make the inv variables.
- inv_update_thunks: A list of thunks that make the inv update ops.
- """
- self._check_vars_unmade_and_set_made_flag()
-
- self._finalize_layer_collection()
-
- scope = self.name if scope is None else scope
-
- cov_variable_thunks = [
- self._create_cov_variable_thunk(factor, scope)
- for factor in self.factors
- ]
- cov_update_thunks = [
- self._create_cov_update_thunk(factor, scope) for factor in self.factors
- ]
- inv_variable_thunks = [
- self._create_inv_variable_thunk(factor, scope)
- for factor in self.factors
- ]
- inv_update_thunks = [
- self._create_inv_update_thunk(factor, scope) for factor in self.factors
- ]
-
- return (cov_variable_thunks, cov_update_thunks,
- inv_variable_thunks, inv_update_thunks)
-
- def _create_cov_variable_thunk(self, factor, scope):
- """Constructs a covariance variable thunk for a single FisherFactor."""
-
- def thunk():
- with variable_scope.variable_scope(scope):
- return factor.instantiate_cov_variables()
-
- return thunk
-
- def _create_cov_update_thunk(self, factor, scope):
- """Constructs a covariance update thunk for a single FisherFactor."""
-
- def thunk():
- with variable_scope.variable_scope(scope):
- return factor.make_covariance_update_op(self._cov_ema_decay)
-
- return thunk
-
- def _create_inv_variable_thunk(self, factor, scope):
- """Constructs a inverse variable thunk for a single FisherFactor."""
-
- def thunk():
- with variable_scope.variable_scope(scope):
- return factor.instantiate_inv_variables()
-
- return thunk
-
- def _create_inv_update_thunk(self, factor, scope):
- """Constructs an inverse update thunk for a single FisherFactor."""
-
- def thunk():
- with variable_scope.variable_scope(scope):
- return control_flow_ops.group(factor.make_inverse_update_ops())
-
- return thunk
-
- def _get_grads_lists_gradients(self, tensors):
- # Passing in a list of loss values is better than passing in the sum as
- # the latter creates unnessesary ops on the default device
- grads_flat = gradients_impl.gradients(
- self._layers.eval_losses_on_samples(),
- nest.flatten(tensors),
- colocate_gradients_with_ops=self._colocate_gradients_with_ops)
- grads_all = nest.pack_sequence_as(tensors, grads_flat)
- return tuple((grad,) for grad in grads_all)
-
- def _get_grads_lists_empirical(self, tensors):
- # Passing in a list of loss values is better than passing in the sum as
- # the latter creates unnecessary ops on the default device
- grads_flat = gradients_impl.gradients(
- self._layers.eval_losses(),
- nest.flatten(tensors),
- colocate_gradients_with_ops=self._colocate_gradients_with_ops)
- grads_all = nest.pack_sequence_as(tensors, grads_flat)
- return tuple((grad,) for grad in grads_all)
-
- def _get_transformed_random_signs(self):
- transformed_random_signs = []
- for loss in self._layers.losses:
- with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
- transformed_random_signs.append(
- loss.multiply_fisher_factor(
- utils.generate_random_signs(loss.fisher_factor_inner_shape)))
- return transformed_random_signs
-
- def _get_grads_lists_curvature_prop(self, tensors):
- loss_inputs = list(loss.inputs for loss in self._layers.losses)
- transformed_random_signs = self._get_transformed_random_signs()
- grads_flat = gradients_impl.gradients(
- nest.flatten(loss_inputs),
- nest.flatten(tensors),
- grad_ys=nest.flatten(transformed_random_signs),
- colocate_gradients_with_ops=self._colocate_gradients_with_ops)
- grads_all = nest.pack_sequence_as(tensors, grads_flat)
- return tuple((grad,) for grad in grads_all)
-
- def _get_grads_lists_exact(self, tensors):
- """No docstring required."""
- # Loop over all coordinates of all losses.
- grads_all = []
- for loss in self._layers.losses:
- with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
- for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
- transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
- index)
- grads_flat = gradients_impl.gradients(
- loss.inputs,
- nest.flatten(tensors),
- grad_ys=transformed_one_hot,
- colocate_gradients_with_ops=self._colocate_gradients_with_ops)
- grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
- return zip(*grads_all)
-
-
-class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin,
- FisherEstimator):
- """Fisher estimator which provides round robin device placement strategy."""
- pass
diff --git a/tensorflow/contrib/kfac/python/ops/estimator_lib.py b/tensorflow/contrib/kfac/python/ops/estimator_lib.py
deleted file mode 100644
index 9c9fef471f..0000000000
--- a/tensorflow/contrib/kfac/python/ops/estimator_lib.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Defines the high-level Fisher estimator class."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.estimator import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- 'FisherEstimator',
- 'make_fisher_estimator',
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
deleted file mode 100644
index 9fa6eb7dcd..0000000000
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ /dev/null
@@ -1,1752 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""FisherBlock definitions.
-
-This library contains classes for estimating blocks in a model's Fisher
-Information matrix. Suppose one has a model that parameterizes a posterior
-distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its
-Fisher Information matrix is given by,
-
- $$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$
-
-where,
-
- $$v(x, y, params) = (d / d params) log p(y | x, params)$$
-
-and the expectation is taken with respect to the data's distribution for 'x' and
-the model's posterior distribution for 'y',
-
- x ~ p(x)
- y ~ p(y | x, params)
-
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-import enum # pylint: disable=g-bad-import-order
-
-import numpy as np
-import six
-
-from tensorflow.contrib.kfac.python.ops import fisher_factors
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.util import nest
-
-# For blocks corresponding to convolutional layers, or any type of block where
-# the parameters can be thought of as being replicated in time or space,
-# we want to adjust the scale of the damping by
-# damping /= num_replications ** NORMALIZE_DAMPING_POWER
-NORMALIZE_DAMPING_POWER = 1.0
-
-# Methods for adjusting damping for FisherBlocks. See
-# compute_pi_adjusted_damping() for details.
-PI_OFF_NAME = "off"
-PI_TRACENORM_NAME = "tracenorm"
-PI_TYPE = PI_TRACENORM_NAME
-
-
-def set_global_constants(normalize_damping_power=None, pi_type=None):
- """Sets various global constants used by the classes in this module."""
- global NORMALIZE_DAMPING_POWER
- global PI_TYPE
-
- if normalize_damping_power is not None:
- NORMALIZE_DAMPING_POWER = normalize_damping_power
-
- if pi_type is not None:
- PI_TYPE = pi_type
-
-
-def normalize_damping(damping, num_replications):
- """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER."""
- if NORMALIZE_DAMPING_POWER:
- return damping / (num_replications ** NORMALIZE_DAMPING_POWER)
- return damping
-
-
-def compute_pi_tracenorm(left_cov, right_cov):
- r"""Computes the scalar constant pi for Tikhonov regularization/damping.
-
- $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$
- See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.
-
- Args:
- left_cov: A LinearOperator object. The left Kronecker factor "covariance".
- right_cov: A LinearOperator object. The right Kronecker factor "covariance".
-
- Returns:
- The computed scalar constant pi for these Kronecker Factors (as a Tensor).
- """
- # Instead of dividing by the dim of the norm, we multiply by the dim of the
- # other norm. This works out the same in the ratio.
- left_norm = left_cov.trace() * int(right_cov.domain_dimension)
- right_norm = right_cov.trace() * int(left_cov.domain_dimension)
- return math_ops.sqrt(left_norm / right_norm)
-
-
-def compute_pi_adjusted_damping(left_cov, right_cov, damping):
-
- if PI_TYPE == PI_TRACENORM_NAME:
- pi = compute_pi_tracenorm(left_cov, right_cov)
- return (damping * pi, damping / pi)
-
- elif PI_TYPE == PI_OFF_NAME:
- return (damping, damping)
-
-
-class PackagedFunc(object):
- """A Python thunk with a stable ID.
-
- Enables stable names for lambdas.
- """
-
- def __init__(self, func, func_id):
- """Initializes PackagedFunc.
-
- Args:
- func: a zero-arg Python function.
- func_id: a hashable, function that produces a hashable, or a list/tuple
- thereof.
- """
- self._func = func
- func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,)
- self._func_id = func_id
-
- def __call__(self):
- return self._func()
-
- @property
- def func_id(self):
- """A hashable identifier for this function."""
- return tuple(elt() if callable(elt) else elt for elt in self._func_id)
-
-
-def _package_func(func, func_id):
- return PackagedFunc(func, func_id)
-
-
-@six.add_metaclass(abc.ABCMeta)
-class FisherBlock(object):
- """Abstract base class for objects modeling approximate Fisher matrix blocks.
-
- Subclasses must implement register_matpower, multiply_matpower,
- instantiate_factors, tensors_to_compute_grads, and num_registered_towers
- methods.
- """
-
- def __init__(self, layer_collection):
- self._layer_collection = layer_collection
-
- @abc.abstractmethod
- def instantiate_factors(self, grads_list, damping):
- """Creates and registers the component factors of this Fisher block.
-
- Args:
- grads_list: A list gradients (each a Tensor or tuple of Tensors) with
- respect to the tensors returned by tensors_to_compute_grads() that
- are to be used to estimate the block.
- damping: The damping factor (float or Tensor).
- """
- pass
-
- @abc.abstractmethod
- def register_matpower(self, exp):
- """Registers a matrix power to be computed by the block.
-
- Args:
- exp: A float representing the power to raise the block by.
- """
- pass
-
- @abc.abstractmethod
- def register_cholesky(self):
- """Registers a Cholesky factor to be computed by the block."""
- pass
-
- @abc.abstractmethod
- def register_cholesky_inverse(self):
- """Registers an inverse Cholesky factor to be computed by the block."""
- pass
-
- def register_inverse(self):
- """Registers a matrix inverse to be computed by the block."""
- self.register_matpower(-1)
-
- @abc.abstractmethod
- def multiply_matpower(self, vector, exp):
- """Multiplies the vector by the (damped) matrix-power of the block.
-
- Args:
- vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
- exp: A float representing the power to raise the block by before
- multiplying it by the vector.
-
- Returns:
- The vector left-multiplied by the (damped) matrix-power of the block.
- """
- pass
-
- def multiply_inverse(self, vector):
- """Multiplies the vector by the (damped) inverse of the block.
-
- Args:
- vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
-
- Returns:
- The vector left-multiplied by the (damped) inverse of the block.
- """
- return self.multiply_matpower(vector, -1)
-
- def multiply(self, vector):
- """Multiplies the vector by the (damped) block.
-
- Args:
- vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
-
- Returns:
- The vector left-multiplied by the (damped) block.
- """
- return self.multiply_matpower(vector, 1)
-
- @abc.abstractmethod
- def multiply_cholesky(self, vector, transpose=False):
- """Multiplies the vector by the (damped) Cholesky-factor of the block.
-
- Args:
- vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
- transpose: Bool. If true the Cholesky factor is transposed before
- multiplying the vector. (Default: False)
-
- Returns:
- The vector left-multiplied by the (damped) Cholesky-factor of the block.
- """
- pass
-
- @abc.abstractmethod
- def multiply_cholesky_inverse(self, vector, transpose=False):
- """Multiplies vector by the (damped) inverse Cholesky-factor of the block.
-
- Args:
- vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
- transpose: Bool. If true the Cholesky factor inverse is transposed
- before multiplying the vector. (Default: False)
- Returns:
- Vector left-multiplied by (damped) inverse Cholesky-factor of the block.
- """
- pass
-
- @abc.abstractmethod
- def tensors_to_compute_grads(self):
- """Returns the Tensor(s) with respect to which this FisherBlock needs grads.
- """
- pass
-
- @abc.abstractproperty
- def num_registered_towers(self):
- """Number of towers registered for this FisherBlock.
-
- Typically equal to the number of towers in a multi-tower setup.
- """
- pass
-
-
-class FullFB(FisherBlock):
- """FisherBlock using a full matrix estimate (no approximations).
-
- FullFB uses a full matrix estimate (no approximations), and should only ever
- be used for very low dimensional parameters.
-
- Note that this uses the naive "square the sum estimator", and so is applicable
- to any type of parameter in principle, but has very high variance.
- """
-
- def __init__(self, layer_collection, params):
- """Creates a FullFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: The parameters of this layer (Tensor or tuple of Tensors).
- """
- self._batch_sizes = []
- self._params = params
-
- super(FullFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- self._damping_func = _package_func(lambda: damping, (damping,))
-
- self._factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullFactor, (grads_list, self._batch_size))
-
- def register_matpower(self, exp):
- self._factor.register_matpower(exp, self._damping_func)
-
- def register_cholesky(self):
- self._factor.register_cholesky(self._damping_func)
-
- def register_cholesky_inverse(self):
- self._factor.register_cholesky_inverse(self._damping_func)
-
- def _multiply_matrix(self, matrix, vector, transpose=False):
- vector_flat = utils.tensors_to_column(vector)
- out_flat = matrix.matmul(vector_flat, adjoint=transpose)
- return utils.column_to_tensors(vector, out_flat)
-
- def multiply_matpower(self, vector, exp):
- matrix = self._factor.get_matpower(exp, self._damping_func)
- return self._multiply_matrix(matrix, vector)
-
- def multiply_cholesky(self, vector, transpose=False):
- matrix = self._factor.get_cholesky(self._damping_func)
- return self._multiply_matrix(matrix, vector, transpose=transpose)
-
- def multiply_cholesky_inverse(self, vector, transpose=False):
- matrix = self._factor.get_cholesky_inverse(self._damping_func)
- return self._multiply_matrix(matrix, vector, transpose=transpose)
-
- def full_fisher_block(self):
- """Explicitly constructs the full Fisher block."""
- return self._factor.get_cov_as_linear_operator().to_dense()
-
- def tensors_to_compute_grads(self):
- return self._params
-
- def register_additional_tower(self, batch_size):
- """Register an additional tower.
-
- Args:
- batch_size: The batch size, used in the covariance estimator.
- """
- self._batch_sizes.append(batch_size)
-
- @property
- def num_registered_towers(self):
- return len(self._batch_sizes)
-
- @property
- def _batch_size(self):
- return math_ops.reduce_sum(self._batch_sizes)
-
-
-@six.add_metaclass(abc.ABCMeta)
-class DiagonalFB(FisherBlock):
- """A base class for FisherBlocks that use diagonal approximations."""
-
- def register_matpower(self, exp):
- # Not needed for this. Matrix powers are computed on demand in the
- # diagonal case
- pass
-
- def register_cholesky(self):
- # Not needed for this. Cholesky's are computed on demand in the
- # diagonal case
- pass
-
- def register_cholesky_inverse(self):
- # Not needed for this. Cholesky inverses's are computed on demand in the
- # diagonal case
- pass
-
- def _multiply_matrix(self, matrix, vector):
- vector_flat = utils.tensors_to_column(vector)
- out_flat = matrix.matmul(vector_flat)
- return utils.column_to_tensors(vector, out_flat)
-
- def multiply_matpower(self, vector, exp):
- matrix = self._factor.get_matpower(exp, self._damping_func)
- return self._multiply_matrix(matrix, vector)
-
- def multiply_cholesky(self, vector, transpose=False):
- matrix = self._factor.get_cholesky(self._damping_func)
- return self._multiply_matrix(matrix, vector)
-
- def multiply_cholesky_inverse(self, vector, transpose=False):
- matrix = self._factor.get_cholesky_inverse(self._damping_func)
- return self._multiply_matrix(matrix, vector)
-
- def full_fisher_block(self):
- return self._factor.get_cov_as_linear_operator().to_dense()
-
-
-class NaiveDiagonalFB(DiagonalFB):
- """FisherBlock using a diagonal matrix approximation.
-
- This type of approximation is generically applicable but quite primitive.
-
- Note that this uses the naive "square the sum estimator", and so is applicable
- to any type of parameter in principle, but has very high variance.
- """
-
- def __init__(self, layer_collection, params):
- """Creates a NaiveDiagonalFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: The parameters of this layer (Tensor or tuple of Tensors).
- """
- self._params = params
- self._batch_sizes = []
-
- super(NaiveDiagonalFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- self._damping_func = _package_func(lambda: damping, (damping,))
-
- self._factor = self._layer_collection.make_or_get_factor(
- fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size))
-
- def tensors_to_compute_grads(self):
- return self._params
-
- def register_additional_tower(self, batch_size):
- """Register an additional tower.
-
- Args:
- batch_size: The batch size, used in the covariance estimator.
- """
- self._batch_sizes.append(batch_size)
-
- @property
- def num_registered_towers(self):
- return len(self._batch_sizes)
-
- @property
- def _batch_size(self):
- return math_ops.reduce_sum(self._batch_sizes)
-
-
-class InputOutputMultiTower(object):
- """Mix-in class for blocks with inputs & outputs and multiple mini-batches."""
-
- def __init__(self, *args, **kwargs):
- self.__inputs = []
- self.__outputs = []
- super(InputOutputMultiTower, self).__init__(*args, **kwargs)
-
- def _process_data(self, grads_list):
- """Process data into the format used by the factors.
-
- This function takes inputs and grads_lists data and processes it into
- one of the formats expected by the FisherFactor classes (depending on
- the value of the global configuration variable TOWER_STRATEGY).
-
- The initial format of self._inputs is expected to be a list of Tensors
- over towers. Similarly grads_lists is expected to be a list over sources
- of such lists.
-
- If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single
- tensor (represented as a PartitionedTensor object) equal to the
- concatenation (across towers) of all of the elements of self._inputs. And
- similarly grads_list is formatted into a tuple (over sources) of such
- tensors (also represented as PartitionedTensors).
-
- If TOWER_STRATEGY is "separate", formatting of inputs and grads_list
- remains unchanged from the initial format (although possibly converting
- from lists into tuples).
-
- Args:
- grads_list: grads_list in its initial format (see above).
-
- Returns:
- inputs: self._inputs transformed into the appropriate format (see
- above).
- grads_list: grads_list transformed into the appropriate format (see
- above).
-
- Raises:
- ValueError: if TOWER_STRATEGY is not one of "separate" or "concat".
- """
- inputs = self._inputs
- # inputs is a list over towers of Tensors
- # grads_list is a list of list with the first index being sources and the
- # second being towers.
- if fisher_factors.TOWER_STRATEGY == "concat":
- # Merge towers together into a PartitionedTensor. We package it in
- # a singleton tuple since the factors will expect a list over towers
- inputs = (utils.PartitionedTensor(inputs),)
- # Do the same for grads_list but preserve leading sources dimension
- grads_list = tuple((utils.PartitionedTensor(grads),)
- for grads in grads_list)
- elif fisher_factors.TOWER_STRATEGY == "separate":
- inputs = tuple(inputs)
- grads_list = tuple(grads_list)
-
- else:
- raise ValueError("Global config variable TOWER_STRATEGY must be one of "
- "'concat' or 'separate'.")
-
- return inputs, grads_list
-
- def tensors_to_compute_grads(self):
- """Tensors to compute derivative of loss with respect to."""
- return tuple(self._outputs)
-
- def register_additional_tower(self, inputs, outputs):
- self._inputs.append(inputs)
- self._outputs.append(outputs)
-
- @property
- def num_registered_towers(self):
- result = len(self._inputs)
- assert result == len(self._outputs)
- return result
-
- @property
- def _inputs(self):
- return self.__inputs
-
- @property
- def _outputs(self):
- return self.__outputs
-
-
-class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB):
- """FisherBlock for fully-connected (dense) layers using a diagonal approx.
-
- Estimates the Fisher Information matrix's diagonal entries for a fully
- connected layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of
- squares" estimator.
-
- Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
- into it. We are interested in Fisher(params)[i, i]. This is,
-
- $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
- = E[ v(x, y, params)[i] ^ 2 ]$$
-
- Consider fully connected layer in this model with (unshared) weight matrix
- 'w'. For an example 'x' that produces layer inputs 'a' and output
- preactivations 's',
-
- $$v(x, y, w) = vec( a (d loss / d s)^T )$$
-
- This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
- to the layer's parameters 'w'.
- """
-
- def __init__(self, layer_collection, has_bias=False):
- """Creates a FullyConnectedDiagonalFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- has_bias: Whether the component Kronecker factors have an additive bias.
- (Default: False)
- """
- self._has_bias = has_bias
-
- super(FullyConnectedDiagonalFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- self._factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedDiagonalFactor,
- (inputs, grads_list, self._has_bias))
-
- self._damping_func = _package_func(lambda: damping, (damping,))
-
-
-class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB):
- """FisherBlock for 2-D convolutional layers using a diagonal approx.
-
- Estimates the Fisher Information matrix's diagonal entries for a convolutional
- layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares"
- estimator.
-
- Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
- into it. We are interested in Fisher(params)[i, i]. This is,
-
- $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
- = E[ v(x, y, params)[i] ^ 2 ]$$
-
- Consider a convoluational layer in this model with (unshared) filter matrix
- 'w'. For an example image 'x' that produces layer inputs 'a' and output
- preactivations 's',
-
- $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$
-
- where 'loc' is a single (x, y) location in an image.
-
- This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
- to the layer's parameters 'w'.
- """
-
- def __init__(self,
- layer_collection,
- params,
- strides,
- padding,
- data_format=None,
- dilations=None):
- """Creates a ConvDiagonalFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: The parameters (Tensor or tuple of Tensors) of this layer. If
- kernel alone, a Tensor of shape [kernel_height, kernel_width,
- in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
- containing the previous and a Tensor of shape [out_channels].
- strides: The stride size in this layer (1-D Tensor of length 4).
- padding: The padding in this layer (e.g. "SAME").
- data_format: str or None. Format of input data.
- dilations: List of 4 ints or None. Rate for dilation along all dimensions.
-
- Raises:
- ValueError: if strides is not length-4.
- ValueError: if dilations is not length-4.
- ValueError: if channel is not last dimension.
- """
- if len(strides) != 4:
- raise ValueError("strides must contain 4 numbers.")
-
- if dilations is None:
- dilations = [1, 1, 1, 1]
-
- if len(dilations) != 4:
- raise ValueError("dilations must contain 4 numbers.")
-
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("data_format must be channels-last.")
-
- self._strides = maybe_tuple(strides)
- self._padding = padding
- self._data_format = data_format
- self._dilations = maybe_tuple(dilations)
- self._has_bias = isinstance(params, (tuple, list))
-
- fltr = params[0] if self._has_bias else params
- self._filter_shape = tuple(fltr.shape.as_list())
-
- if len(self._filter_shape) != 4:
- raise ValueError(
- "Convolution filter must be of shape"
- " [filter_height, filter_width, in_channels, out_channels].")
-
- super(ConvDiagonalFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- # Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
- self._strides)
-
- self._factor = self._layer_collection.make_or_get_factor(
- fisher_factors.ConvDiagonalFactor,
- (inputs, grads_list, self._filter_shape, self._strides, self._padding,
- self._data_format, self._dilations, self._has_bias))
-
- def damping_func():
- return self._num_locations * normalize_damping(damping,
- self._num_locations)
-
- damping_id = (self._num_locations, "mult", "normalize_damping", damping,
- self._num_locations)
- self._damping_func = _package_func(damping_func, damping_id)
-
-
-class KroneckerProductFB(FisherBlock):
- """A base class for blocks with separate input and output Kronecker factors.
-
- The Fisher block is approximated as a Kronecker product of the input and
- output factors.
- """
-
- def _setup_damping(self, damping, normalization=None):
- """Makes functions that compute the damping values for both factors."""
- def compute_damping():
- if normalization is not None:
- maybe_normalized_damping = normalize_damping(damping, normalization)
- else:
- maybe_normalized_damping = damping
-
- return compute_pi_adjusted_damping(
- self._input_factor.get_cov_as_linear_operator(),
- self._output_factor.get_cov_as_linear_operator(),
- maybe_normalized_damping**0.5)
-
- if normalization is not None:
- damping_id = ("compute_pi_adjusted_damping",
- "cov", self._input_factor.name,
- "cov", self._output_factor.name,
- "normalize_damping", damping, normalization, "power", 0.5)
- else:
- damping_id = ("compute_pi_adjusted_damping",
- "cov", self._input_factor.name,
- "cov", self._output_factor.name,
- damping, "power", 0.5)
-
- self._input_damping_func = _package_func(lambda: compute_damping()[0],
- damping_id + ("ref", 0))
- self._output_damping_func = _package_func(lambda: compute_damping()[1],
- damping_id + ("ref", 1))
-
- def register_matpower(self, exp):
- self._input_factor.register_matpower(exp, self._input_damping_func)
- self._output_factor.register_matpower(exp, self._output_damping_func)
-
- def register_cholesky(self):
- self._input_factor.register_cholesky(self._input_damping_func)
- self._output_factor.register_cholesky(self._output_damping_func)
-
- def register_cholesky_inverse(self):
- self._input_factor.register_cholesky_inverse(self._input_damping_func)
- self._output_factor.register_cholesky_inverse(self._output_damping_func)
-
- @property
- def _renorm_coeff(self):
- """Kronecker factor multiplier coefficient.
-
- If this FisherBlock is represented as 'FB = c * kron(left, right)', then
- this is 'c'.
-
- Returns:
- 0-D Tensor.
- """
- return 1.0
-
- def _multiply_factored_matrix(self, left_factor, right_factor, vector,
- extra_scale=1.0, transpose_left=False,
- transpose_right=False):
- reshaped_vector = utils.layer_params_to_mat2d(vector)
- reshaped_out = right_factor.matmul_right(reshaped_vector,
- adjoint=transpose_right)
- reshaped_out = left_factor.matmul(reshaped_out,
- adjoint=transpose_left)
- if extra_scale != 1.0:
- reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype)
- return utils.mat2d_to_layer_params(vector, reshaped_out)
-
- def multiply_matpower(self, vector, exp):
- left_factor = self._input_factor.get_matpower(
- exp, self._input_damping_func)
- right_factor = self._output_factor.get_matpower(
- exp, self._output_damping_func)
- extra_scale = float(self._renorm_coeff)**exp
- return self._multiply_factored_matrix(left_factor, right_factor, vector,
- extra_scale=extra_scale)
-
- def multiply_cholesky(self, vector, transpose=False):
- left_factor = self._input_factor.get_cholesky(self._input_damping_func)
- right_factor = self._output_factor.get_cholesky(self._output_damping_func)
- extra_scale = float(self._renorm_coeff)**0.5
- return self._multiply_factored_matrix(left_factor, right_factor, vector,
- extra_scale=extra_scale,
- transpose_left=transpose,
- transpose_right=not transpose)
-
- def multiply_cholesky_inverse(self, vector, transpose=False):
- left_factor = self._input_factor.get_cholesky_inverse(
- self._input_damping_func)
- right_factor = self._output_factor.get_cholesky_inverse(
- self._output_damping_func)
- extra_scale = float(self._renorm_coeff)**-0.5
- return self._multiply_factored_matrix(left_factor, right_factor, vector,
- extra_scale=extra_scale,
- transpose_left=transpose,
- transpose_right=not transpose)
-
- def full_fisher_block(self):
- """Explicitly constructs the full Fisher block.
-
- Used for testing purposes. (In general, the result may be very large.)
-
- Returns:
- The full Fisher block.
- """
- left_factor = self._input_factor.get_cov_as_linear_operator().to_dense()
- right_factor = self._output_factor.get_cov_as_linear_operator().to_dense()
- return self._renorm_coeff * utils.kronecker_product(left_factor,
- right_factor)
-
-
-class EmbeddingKFACFB(InputOutputMultiTower, KroneckerProductFB):
- """K-FAC FisherBlock for embedding layers.
-
- This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its
- input factor is approximated by a diagonal matrix. In the case that each
- example references exactly one embedding, this approximation is exact.
-
- Does not support bias parameters.
- """
-
- def __init__(self, layer_collection, vocab_size):
- """Creates a EmbeddingKFACFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- vocab_size: int. Size of vocabulary for this embedding layer.
- """
- self._vocab_size = vocab_size
-
- super(EmbeddingKFACFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- """Instantiate Kronecker Factors for this FisherBlock.
-
- Args:
- grads_list: List of list of Tensors. grads_list[i][j] is the
- gradient of the loss with respect to 'outputs' from source 'i' and
- tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
- damping: 0-D Tensor or float. 'damping' * identity is approximately added
- to this FisherBlock's Fisher approximation.
- """
- inputs, grads_list = self._process_data(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.EmbeddingInputKroneckerFactor,
- (inputs, self._vocab_size))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedKroneckerFactor, (grads_list,))
- self._setup_damping(damping)
-
-
-class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):
- """K-FAC FisherBlock for fully-connected (dense) layers.
-
- This uses the Kronecker-factorized approximation from the original
- K-FAC paper (https://arxiv.org/abs/1503.05671)
- """
-
- def __init__(self, layer_collection, has_bias=False):
- """Creates a FullyConnectedKFACBasicFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- has_bias: Whether the component Kronecker factors have an additive bias.
- (Default: False)
- """
- self._has_bias = has_bias
-
- super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- """Instantiate Kronecker Factors for this FisherBlock.
-
- Args:
- grads_list: List of list of Tensors. grads_list[i][j] is the
- gradient of the loss with respect to 'outputs' from source 'i' and
- tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
- damping: 0-D Tensor or float. 'damping' * identity is approximately added
- to this FisherBlock's Fisher approximation.
- """
- inputs, grads_list = self._process_data(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedKroneckerFactor,
- ((inputs,), self._has_bias))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedKroneckerFactor,
- (grads_list,))
- self._setup_damping(damping)
-
-
-class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
- r"""FisherBlock for convolutional layers using the basic KFC approx.
-
- Estimates the Fisher Information matrix's blog for a convolutional
- layer.
-
- Consider a convolutional layer in this model with (unshared) filter matrix
- 'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
- this FisherBlock estimates,
-
- $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T],
- E[flat(ds) flat(ds)^T])$$
-
- where
-
- $$ds = (d / ds) log p(y | x, w)$$
- #locations = number of (x, y) locations where 'w' is applied.
-
- where the expectation is taken over all examples and locations and flat()
- concatenates an array's leading dimensions.
-
- See equation 23 in https://arxiv.org/abs/1602.01407 for details.
- """
-
- def __init__(self,
- layer_collection,
- params,
- padding,
- strides=None,
- dilation_rate=None,
- data_format=None,
- extract_patches_fn=None):
- """Creates a ConvKFCBasicFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: The parameters (Tensor or tuple of Tensors) of this layer. If
- kernel alone, a Tensor of shape [..spatial_filter_shape..,
- in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
- containing the previous and a Tensor of shape [out_channels].
- padding: str. Padding method.
- strides: List of ints or None. Contains [..spatial_filter_strides..] if
- 'extract_patches_fn' is compatible with tf.nn.convolution(), else
- [1, ..spatial_filter_strides, 1].
- dilation_rate: List of ints or None. Rate for dilation along each spatial
- dimension if 'extract_patches_fn' is compatible with
- tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
- data_format: str or None. Format of input data.
- extract_patches_fn: str or None. Name of function that extracts image
- patches. One of "extract_convolution_patches", "extract_image_patches",
- "extract_pointwise_conv2d_patches".
- """
- self._padding = padding
- self._strides = maybe_tuple(strides)
- self._dilation_rate = maybe_tuple(dilation_rate)
- self._data_format = data_format
- self._extract_patches_fn = extract_patches_fn
- self._has_bias = isinstance(params, (tuple, list))
-
- fltr = params[0] if self._has_bias else params
- self._filter_shape = tuple(fltr.shape.as_list())
-
- super(ConvKFCBasicFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- # Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
- self._strides)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.ConvInputKroneckerFactor,
- (inputs, self._filter_shape, self._padding, self._strides,
- self._dilation_rate, self._data_format, self._extract_patches_fn,
- self._has_bias))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
-
- self._setup_damping(damping, normalization=self._num_locations)
-
- @property
- def _renorm_coeff(self):
- return self._num_locations
-
-
-class DepthwiseConvDiagonalFB(ConvDiagonalFB):
- """FisherBlock for depthwise_conv2d().
-
- Equivalent to ConvDiagonalFB applied to each input channel in isolation.
- """
-
- def __init__(self,
- layer_collection,
- params,
- strides,
- padding,
- rate=None,
- data_format=None):
- """Creates a DepthwiseConvKFCBasicFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: Tensor of shape [filter_height, filter_width, in_channels,
- channel_multiplier].
- strides: List of 4 ints. Strides along all dimensions.
- padding: str. Padding method.
- rate: List of 4 ints or None. Rate for dilation along all dimensions.
- data_format: str or None. Format of input data.
-
- Raises:
- NotImplementedError: If parameters contains bias.
- ValueError: If filter is not 4-D.
- ValueError: If strides is not length-4.
- ValueError: If rates is not length-2.
- ValueError: If channels are not last dimension.
- """
- if isinstance(params, (tuple, list)):
- raise NotImplementedError("Bias not yet supported.")
-
- if params.shape.ndims != 4:
- raise ValueError("Filter must be 4-D.")
-
- if len(strides) != 4:
- raise ValueError("strides must account for 4 dimensions.")
-
- if rate is not None:
- if len(rate) != 2:
- raise ValueError("rate must only account for spatial dimensions.")
- rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate.
-
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("data_format must be channels-last.")
-
- super(DepthwiseConvDiagonalFB, self).__init__(
- layer_collection=layer_collection,
- params=params,
- strides=strides,
- padding=padding,
- dilations=rate,
- data_format=data_format)
-
- # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__().
- filter_height, filter_width, in_channels, channel_multiplier = (
- params.shape.as_list())
- self._filter_shape = (filter_height, filter_width, in_channels,
- in_channels * channel_multiplier)
-
- def _multiply_matrix(self, matrix, vector):
- conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
- conv2d_result = super(
- DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector)
- return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
-
-
-class DepthwiseConvKFCBasicFB(ConvKFCBasicFB):
- """FisherBlock for depthwise_conv2d().
-
- Equivalent to ConvKFCBasicFB applied to each input channel in isolation.
- """
-
- def __init__(self,
- layer_collection,
- params,
- strides,
- padding,
- rate=None,
- data_format=None):
- """Creates a DepthwiseConvKFCBasicFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: Tensor of shape [filter_height, filter_width, in_channels,
- channel_multiplier].
- strides: List of 4 ints. Strides along all dimensions.
- padding: str. Padding method.
- rate: List of 4 ints or None. Rate for dilation along all dimensions.
- data_format: str or None. Format of input data.
-
- Raises:
- NotImplementedError: If parameters contains bias.
- ValueError: If filter is not 4-D.
- ValueError: If strides is not length-4.
- ValueError: If rates is not length-2.
- ValueError: If channels are not last dimension.
- """
- if isinstance(params, (tuple, list)):
- raise NotImplementedError("Bias not yet supported.")
-
- if params.shape.ndims != 4:
- raise ValueError("Filter must be 4-D.")
-
- if len(strides) != 4:
- raise ValueError("strides must account for 4 dimensions.")
-
- if rate is not None:
- if len(rate) != 2:
- raise ValueError("rate must only account for spatial dimensions.")
- rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate.
-
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("data_format must be channels-last.")
-
- super(DepthwiseConvKFCBasicFB, self).__init__(
- layer_collection=layer_collection,
- params=params,
- padding=padding,
- strides=strides,
- dilation_rate=rate,
- data_format=data_format,
- extract_patches_fn="extract_image_patches")
-
- # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__().
- filter_height, filter_width, in_channels, channel_multiplier = (
- params.shape.as_list())
- self._filter_shape = (filter_height, filter_width, in_channels,
- in_channels * channel_multiplier)
-
- def _multiply_factored_matrix(self, left_factor, right_factor, vector,
- extra_scale=1.0, transpose_left=False,
- transpose_right=False):
- conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
- conv2d_result = super(
- DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix(
- left_factor, right_factor, conv2d_vector, extra_scale=extra_scale,
- transpose_left=transpose_left, transpose_right=transpose_right)
- return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
-
-
-def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin
- """Converts a convolution filter for use with conv2d.
-
- Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's
- compatible with tf.nn.conv2d().
-
- Args:
- filter: Tensor of shape [height, width, in_channels, channel_multiplier].
- name: None or str. Name of Op.
-
- Returns:
- Tensor of shape [height, width, in_channels, out_channels].
-
- """
- with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter",
- [filter]):
- filter = ops.convert_to_tensor(filter)
- filter_height, filter_width, in_channels, channel_multiplier = (
- filter.shape.as_list())
-
- results = []
- for i in range(in_channels):
- # Slice out one in_channel's filter. Insert zeros around it to force it
- # to affect that channel and that channel alone.
- elements = []
- if i > 0:
- elements.append(
- array_ops.zeros(
- [filter_height, filter_width, i, channel_multiplier]))
- elements.append(filter[:, :, i:(i + 1), :])
- if i + 1 < in_channels:
- elements.append(
- array_ops.zeros([
- filter_height, filter_width, in_channels - (i + 1),
- channel_multiplier
- ]))
-
- # Concat along in_channel.
- results.append(
- array_ops.concat(elements, axis=-2, name="in_channel_%d" % i))
-
- # Concat along out_channel.
- return array_ops.concat(results, axis=-1, name="out_channel")
-
-
-def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin
- """Converts a convolution filter for use with depthwise_conv2d.
-
- Transforms a filter for use with tf.nn.conv2d() to one that's
- compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along
- the diagonal.
-
- Args:
- filter: Tensor of shape [height, width, in_channels, out_channels].
- name: None or str. Name of Op.
-
- Returns:
- Tensor of shape,
- [height, width, in_channels, channel_multiplier]
-
- Raises:
- ValueError: if out_channels is not evenly divisible by in_channels.
- """
- with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter",
- [filter]):
- filter = ops.convert_to_tensor(filter)
- filter_height, filter_width, in_channels, out_channels = (
- filter.shape.as_list())
-
- if out_channels % in_channels != 0:
- raise ValueError("out_channels must be evenly divisible by in_channels.")
- channel_multiplier = out_channels // in_channels
-
- results = []
- filter = array_ops.reshape(filter, [
- filter_height, filter_width, in_channels, in_channels,
- channel_multiplier
- ])
- for i in range(in_channels):
- # Slice out output corresponding to the correct filter.
- filter_slice = array_ops.reshape(
- filter[:, :, i, i, :],
- [filter_height, filter_width, 1, channel_multiplier])
- results.append(filter_slice)
-
- # Concat along out_channel.
- return array_ops.concat(results, axis=-2, name="in_channels")
-
-
-def maybe_tuple(obj):
- if not isinstance(obj, list):
- return obj
- return tuple(obj)
-
-
-def num_conv_locations(input_shape, strides):
- """Returns the number of spatial locations a 2D Conv kernel is applied to.
-
- Args:
- input_shape: List of ints representing shape of inputs to
- tf.nn.convolution().
- strides: List of ints representing strides along spatial dimensions as
- passed in to tf.nn.convolution().
-
- Returns:
- A scalar |T| denoting the number of spatial locations for the Conv layer.
- """
- spatial_input_locations = np.prod(input_shape[1:-1])
-
- if strides is None:
- spatial_strides_divisor = 1
- else:
- spatial_strides_divisor = np.prod(strides)
-
- return spatial_input_locations // spatial_strides_divisor
-
-
-class InputOutputMultiTowerMultiUse(InputOutputMultiTower):
- """Adds methods for multi-use/time-step case to InputOutputMultiTower."""
-
- def __init__(self, num_uses=None, *args, **kwargs):
- self._num_uses = num_uses
- super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs)
-
- def _process_data(self, grads_list):
- """Process temporal/multi-use data into the format used by the factors.
-
- This function takes inputs and grads_lists data and processes it into
- one of the formats expected by the FisherFactor classes (depending on
- the value of the global configuration variable TOWER_STRATEGY).
-
- It accepts the data in one of two initial formats. The first possible
- format is where self._inputs is a list of list of Tensors. The first index
- is tower, the second is use/time-step. grads_list, meanwhile, is a list
- over sources of such lists of lists.
-
- The second possible data format is where self._inputs is a Tensor with
- uses/times-steps folded into the batch dimension. i.e. it is a Tensor
- of shape [num_uses * size_batch, ...] which represents a reshape of a
- Tensor of shape [num_uses, size_batch, ...]. And similarly grads_list is
- a list over sources of such Tensors.
-
- There are two possible formats which inputs and grads_list are transformed
- into.
-
- If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing
- a single tensor (represented as a PartitionedTensor object) with all of
- the data from the towers, as well as the uses/time-steps, concatenated
- together. In this tensor the leading dimension is the batch and
- use/time-step dimensions folded together (with 'use' being the major of
- these two, so that the tensors can be thought of as reshapes of ones of
- shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a
- tuple over sources of such tensors.
-
- If TOWER_STRATEGY is "separate" the inputs are formatted into lists of
- tensors over towers. Each of these tensors has a similar format to
- the tensor produced by the "concat" option, except that each contains
- only the data from a single tower. grads_list is similarly formatted
- into a tuple over sources of such tuples.
-
- Args:
- grads_list: grads_list in its initial format (see above).
-
- Returns:
- inputs: self._inputs transformed into the appropriate format (see
- above).
- grads_list: grads_list transformed into the appropriate format (see
- above).
-
- Raises:
- ValueError: If TOWER_STRATEGY is not one of "separate" or "concat".
- ValueError: If the given/initial format of self._inputs and grads_list
- isn't recognized, or doesn't agree with self._num_uses.
- """
-
- inputs = self._inputs
-
- if isinstance(inputs[0], (list, tuple)):
- num_uses = len(inputs[0])
- if self._num_uses is not None and self._num_uses != num_uses:
- raise ValueError("num_uses argument doesn't match length of inputs.")
- else:
- self._num_uses = num_uses
-
- # Check that all mini-batches/towers have the same number of uses
- if not all(len(input_) == num_uses for input_ in inputs):
- raise ValueError("Length of inputs argument is inconsistent across "
- "towers.")
-
- if fisher_factors.TOWER_STRATEGY == "concat":
- # Reverse the tower and use/time-step indices, so that use is now first,
- # and towers is second
- inputs = tuple(zip(*inputs))
-
- # Flatten the two dimensions
- inputs = nest.flatten(inputs)
-
- # Merge everything together into a PartitionedTensor. We package it in
- # a singleton tuple since the factors will expect a list over towers
- inputs = (utils.PartitionedTensor(inputs),)
-
- elif fisher_factors.TOWER_STRATEGY == "separate":
- # Merge together the uses/time-step dimension into PartitionedTensors,
- # but keep the leading dimension (towers) intact for the factors to
- # process individually.
- inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs)
-
- else:
- raise ValueError("Global config variable TOWER_STRATEGY must be one of "
- "'concat' or 'separate'.")
- else:
- inputs = tuple(inputs)
-
- # Now we perform the analogous processing for grads_list
- if isinstance(grads_list[0][0], (list, tuple)):
- num_uses = len(grads_list[0][0])
- if self._num_uses is not None and self._num_uses != num_uses:
- raise ValueError("num_uses argument doesn't match length of outputs, "
- "or length of outputs is inconsistent with length of "
- "inputs.")
- else:
- self._num_uses = num_uses
-
- if not all(len(grad) == num_uses for grads in grads_list
- for grad in grads):
- raise ValueError("Length of outputs argument is inconsistent across "
- "towers.")
-
- if fisher_factors.TOWER_STRATEGY == "concat":
- # Reverse the tower and use/time-step indices, so that use is now first,
- # and towers is second
- grads_list = tuple(tuple(zip(*grads)) for grads in grads_list)
-
- # Flatten the two dimensions, leaving the leading dimension (source)
- # intact
- grads_list = tuple(nest.flatten(grads) for grads in grads_list)
-
- # Merge inner dimensions together into PartitionedTensors. We package
- # them in a singleton tuple since the factors will expect a list over
- # towers
- grads_list = tuple((utils.PartitionedTensor(grads),)
- for grads in grads_list)
-
- elif fisher_factors.TOWER_STRATEGY == "separate":
- # Merge together the uses/time-step dimension into PartitionedTensors,
- # but keep the leading dimension (towers) intact for the factors to
- # process individually.
- grads_list = tuple(tuple(utils.PartitionedTensor(grad)
- for grad in grads)
- for grads in grads_list)
-
- else:
- raise ValueError("Global config variable TOWER_STRATEGY must be one of "
- "'concat' or 'separate'.")
- else:
- grads_list = tuple(tuple(grads) for grads in grads_list)
-
- if self._num_uses is None:
- raise ValueError("You must supply a value for the num_uses argument if "
- "the number of uses cannot be inferred from inputs or "
- "outputs arguments (e.g. if they are both given in the "
- "single Tensor format, instead of as lists of Tensors.")
-
- return inputs, grads_list
-
-
-class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse,
- KroneckerProductFB):
- """FisherBlock for fully-connected layers that share parameters.
-
- This class implements the "independence across time" approximation from the
- following paper:
- https://openreview.net/pdf?id=HyMTkQZAb
- """
-
- def __init__(self, layer_collection, has_bias=False, num_uses=None):
- """Creates a FullyConnectedMultiIndepFB block.
-
- Args:
- layer_collection: LayerCollection instance.
- has_bias: bool. If True, estimates Fisher with respect to a bias
- parameter as well as the layer's parameters.
- num_uses: int or None. Number of uses of the layer in the model's graph.
- Only required if the data is formatted with uses/time folded into the
- batch dimension (instead of uses/time being a list dimension).
- (Default: None)
- """
- self._has_bias = has_bias
-
- super(FullyConnectedMultiIndepFB, self).__init__(
- layer_collection=layer_collection,
- num_uses=num_uses)
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF,
- ((inputs,), self._num_uses, self._has_bias))
-
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
-
- self._setup_damping(damping, normalization=self._num_uses)
-
- @property
- def _renorm_coeff(self):
- return float(self._num_uses)
-
-
-class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse,
- KroneckerProductFB):
- """FisherBlock for 2D convolutional layers using the basic KFC approx.
-
- Similar to ConvKFCBasicFB except that this version supports multiple
- uses/time-steps via a standard independence approximation. Similar to the
- "independence across time" used in FullyConnectedMultiIndepFB but generalized
- in the obvious way to conv layers.
- """
-
- def __init__(self,
- layer_collection,
- params,
- padding,
- strides=None,
- dilation_rate=None,
- data_format=None,
- extract_patches_fn=None,
- num_uses=None):
- """Creates a ConvKFCBasicMultiIndepFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: The parameters (Tensor or tuple of Tensors) of this layer. If
- kernel alone, a Tensor of shape [..spatial_filter_shape..,
- in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
- containing the previous and a Tensor of shape [out_channels].
- padding: str. Padding method.
- strides: List of ints or None. Contains [..spatial_filter_strides..] if
- 'extract_patches_fn' is compatible with tf.nn.convolution(), else
- [1, ..spatial_filter_strides, 1].
- dilation_rate: List of ints or None. Rate for dilation along each spatial
- dimension if 'extract_patches_fn' is compatible with
- tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
- data_format: str or None. Format of input data.
- extract_patches_fn: str or None. Name of function that extracts image
- patches. One of "extract_convolution_patches", "extract_image_patches",
- "extract_pointwise_conv2d_patches".
- num_uses: int or None. Number of uses of the layer in the model's graph.
- Only required if the data is formatted with uses/time folded into the
- batch dimension (instead of uses/time being a list dimension).
- (Default: None)
- """
- self._padding = padding
- self._strides = maybe_tuple(strides)
- self._dilation_rate = maybe_tuple(dilation_rate)
- self._data_format = data_format
- self._extract_patches_fn = extract_patches_fn
- self._has_bias = isinstance(params, (tuple, list))
-
- fltr = params[0] if self._has_bias else params
- self._filter_shape = tuple(fltr.shape.as_list())
-
- super(ConvKFCBasicMultiIndepFB, self).__init__(
- layer_collection=layer_collection,
- num_uses=num_uses)
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- # Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
- self._strides)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.ConvInputKroneckerFactor,
- (inputs, self._filter_shape, self._padding, self._strides,
- self._dilation_rate, self._data_format, self._extract_patches_fn,
- self._has_bias))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
-
- self._setup_damping(damping, normalization=
- (self._num_locations * self._num_uses))
-
- @property
- def _renorm_coeff(self):
- return self._num_locations * self._num_uses
-
-
-class EmbeddingKFACMultiIndepFB(InputOutputMultiTowerMultiUse,
- KroneckerProductFB):
- """K-FAC FisherBlock for embedding layers used multiple times in the graph.
-
- Similar to EmbeddingKFACFB except that this version supports multiple uses
- of the parameter within a single model. These uses could correspond to time
- steps in an RNN architecture, but they don't have to.
-
- Does not support bias parameters.
- """
-
- def __init__(self, layer_collection, vocab_size, num_uses=None):
- """Creates a EmbeddingKFACMultiIndepFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- vocab_size: int. Size of vocabulary for this embedding layer.
- num_uses: int or None. Number of uses of the layer in the model's graph.
- Only required if the data is formatted with time folded into the batch
- dimension (instead of time being a list dimension). (Default: None)
- """
- self._vocab_size = vocab_size
-
- super(EmbeddingKFACMultiIndepFB, self).__init__(
- layer_collection=layer_collection,
- num_uses=num_uses)
-
- def instantiate_factors(self, grads_list, damping):
- """Instantiate Kronecker Factors for this FisherBlock.
-
- Args:
- grads_list: List of list of list of Tensors. grads_list[i][j][k] is the
- gradient of the loss with respect to 'outputs' from source 'i',
- tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape
- [tower_minibatch_size, output_size].
- damping: 0-D Tensor or float. 'damping' * identity is approximately added
- to this FisherBlock's Fisher approximation.
- """
- inputs, grads_list = self._process_data(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.EmbeddingInputKroneckerFactor,
- (inputs, self._vocab_size))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
- self._setup_damping(damping, normalization=self._num_uses)
-
- @property
- def _renorm_coeff(self):
- return float(self._num_uses)
-
-
-class SeriesFBApproximation(enum.IntEnum):
- """See FullyConnectedSeriesFB.__init__ for description and usage."""
- option1 = 1
- option2 = 2
-
-
-class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
- KroneckerProductFB):
- """FisherBlock for fully-connected layers that share parameters across time.
-
- This class implements the "Option 1" and "Option 2" approximation from the
- following paper:
- https://openreview.net/pdf?id=HyMTkQZAb
-
- See the end of the appendix of the paper for a pseudo-code of the
- algorithm being implemented by multiply_matpower here. Note that we are
- using pre-computed versions of certain matrix-matrix products to speed
- things up. This is explicitly explained wherever it is done.
- """
-
- def __init__(self,
- layer_collection,
- has_bias=False,
- num_uses=None,
- option=SeriesFBApproximation.option2):
- """Constructs a new `FullyConnectedSeriesFB`.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- has_bias: Whether the layer includes a bias parameter.
- num_uses: int or None. Number of time-steps over which the layer
- is used. Only required if the data is formatted with time folded into
- the batch dimension (instead of time being a list dimension).
- (Default: None)
- option: A `SeriesFBApproximation` specifying the simplifying assumption
- to be used in this block. `option1` approximates the cross-covariance
- over time as a symmetric matrix, while `option2` makes
- the assumption that training sequences are infinitely long. See section
- 3.5 of the paper for more details.
- """
-
- self._has_bias = has_bias
- self._option = option
-
- super(FullyConnectedSeriesFB, self).__init__(
- layer_collection=layer_collection,
- num_uses=num_uses)
-
- @property
- def _num_timesteps(self):
- return self._num_uses
-
- @property
- def _renorm_coeff(self):
- # This should no longer be used since the multiply_X functions from the base
- # class have been overridden
- assert False
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF,
- ((inputs,), self._num_uses, self._has_bias))
- self._input_factor.register_cov_dt1()
-
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
- self._output_factor.register_cov_dt1()
-
- self._setup_damping(damping, normalization=self._num_uses)
-
- def register_matpower(self, exp):
- if exp != -1:
- raise NotImplementedError("FullyConnectedSeriesFB only supports inverse"
- "multiplications.")
-
- if self._option == SeriesFBApproximation.option1:
- self._input_factor.register_option1quants(self._input_damping_func)
- self._output_factor.register_option1quants(self._output_damping_func)
- elif self._option == SeriesFBApproximation.option2:
- self._input_factor.register_option2quants(self._input_damping_func)
- self._output_factor.register_option2quants(self._output_damping_func)
- else:
- raise ValueError(
- "Unrecognized FullyConnectedSeriesFB approximation: {}".format(
- self._option))
-
- def multiply_matpower(self, vector, exp):
- if exp != -1:
- raise NotImplementedError("FullyConnectedSeriesFB only supports inverse"
- "multiplications.")
-
- # pylint: disable=invalid-name
-
- Z = utils.layer_params_to_mat2d(vector)
-
- # Derivations were done for "batch_dim==1" case so we need to convert to
- # that orientation:
- Z = array_ops.transpose(Z)
-
- if self._option == SeriesFBApproximation.option1:
-
- # Note that \\(L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.\\)
- L_A, psi_A = self._input_factor.get_option1quants(
- self._input_damping_func)
- L_G, psi_G = self._output_factor.get_option1quants(
- self._output_damping_func)
-
- def gamma(x):
- # We are assuming that each case has the same number of time-steps.
- # If this stops being the case one shouldn't simply replace this T
- # with its average value. Instead, one needs to go back to the
- # definition of the gamma function from the paper.
- T = self._num_timesteps
- return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T))
-
- # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise)
- # Even though Y is Z-independent we are recomputing it from the psi's
- # each since Y depends on both A and G quantities, and it is relatively
- # cheap to compute.
- Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A)
-
- # \\(Z = L_G^T * Z * L_A\\)
- # This is equivalent to the following computation from the original
- # pseudo-code:
- # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
- # \\(Z = U_G^T * Z * U_A\\)
- Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True)
-
- # \\(Z = Z .* Y\\)
- Z *= Y
-
- # \\(Z = L_G * Z * L_A^T\\)
- # This is equivalent to the following computation from the original
- # pseudo-code:
- # \\(Z = U_G * Z * U_A^T\\)
- # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
- Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True))
-
- elif self._option == SeriesFBApproximation.option2:
-
- # Note that \\(P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1}\\),
- # and \\(K_A = A_0^{-1/2} * E_A\ and\ K_G = G_0^{-1/2} * E_G.\\)
- P_A, K_A, mu_A = self._input_factor.get_option2quants(
- self._input_damping_func)
- P_G, K_G, mu_G = self._output_factor.get_option2quants(
- self._output_damping_func)
-
- # Our approach differs superficially from the pseudo-code in the paper
- # in order to reduce the total number of matrix-matrix multiplies.
- # In particular, the first three computations in the pseudo code are
- # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
- # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\)
- # \\(Z = E_G^T * Z * E_A\\)
- # Noting that hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that
- # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\)
- # the entire computation can be written as
- # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
- # \\( - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\)
- # \\( = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
- # \\( - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\)
- # \\( = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\)
- # \\( - E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\)
- # \\( = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A\\)
- # This final expression is computed by the following two lines:
- # \\(Z = Z - P_G * Z * P_A^T\\)
- Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True))
- # \\(Z = K_G^T * Z * K_A\\)
- Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True)
-
- # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\)
- # Be careful with the outer product. We don't want to accidentally
- # make it an inner-product instead.
- tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A
- # Prevent some numerical issues by setting any 0.0 eigs to 1.0
- tmp += 1.0 * math_ops.cast(math_ops.equal(tmp, 0.0), dtype=tmp.dtype)
- Z /= tmp
-
- # We now perform the transpose/reverse version of the operations
- # derived above, whose derivation from the original pseudo-code is
- # analgous.
- # \\(Z = K_G * Z * K_A^T\\)
- Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True))
-
- # \\(Z = Z - P_G^T * Z * P_A\\)
- Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True)
-
- # \\(Z = normalize (1/E[T]) * Z\\)
- # Note that this normalization is done because we compute the statistics
- # by averaging, not summing, over time. (And the gradient is presumably
- # summed over time, not averaged, and thus their scales are different.)
- Z /= math_ops.cast(self._num_timesteps, Z.dtype)
-
- # Convert back to the "batch_dim==0" orientation.
- Z = array_ops.transpose(Z)
-
- return utils.mat2d_to_layer_params(vector, Z)
-
- # pylint: enable=invalid-name
-
- def multiply_cholesky(self, vector):
- raise NotImplementedError("FullyConnectedSeriesFB does not support "
- "Cholesky computations.")
-
- def multiply_cholesky_inverse(self, vector):
- raise NotImplementedError("FullyConnectedSeriesFB does not support "
- "Cholesky computations.")
-
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
deleted file mode 100644
index c04cf727fa..0000000000
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""FisherBlock definitions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.fisher_blocks import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- 'FisherBlock',
- 'FullFB',
- 'NaiveDiagonalFB',
- 'FullyConnectedDiagonalFB',
- 'KroneckerProductFB',
- 'EmbeddingKFACFB',
- 'FullyConnectedKFACBasicFB',
- 'ConvKFCBasicFB',
- 'ConvDiagonalFB',
- 'set_global_constants',
- 'compute_pi_tracenorm',
- 'compute_pi_adjusted_damping',
- 'num_conv_locations',
- 'normalize_damping',
- 'LEFT_MULTIPLY',
- 'RIGHT_MULTIPLY',
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
deleted file mode 100644
index afa2fd1ca7..0000000000
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ /dev/null
@@ -1,1830 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""FisherFactor definitions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-import contextlib
-
-import numpy as np
-import six
-
-from tensorflow.contrib.kfac.python.ops import linear_operator as lo
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import special_math_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.training import moving_averages
-from tensorflow.python.util import nest
-
-
-# Whether to initialize covariance estimators at a zero matrix (or the identity
-# matrix).
-INIT_COVARIANCES_AT_ZERO = True
-
-# Whether to zero-debias the moving averages.
-ZERO_DEBIAS = True
-
-# Whether to initialize inverse (and other such matrices computed from the cov
-# matrices) to the zero matrix (or the identity matrix).
-INIT_INVERSES_AT_ZERO = True
-
-# When the number of inverses requested from a FisherFactor exceeds this value,
-# the inverses are computed using an eigenvalue decomposition.
-EIGENVALUE_DECOMPOSITION_THRESHOLD = 2
-
-# Numerical eigenvalues computed from covariance matrix estimates are clipped to
-# be at least as large as this value before they are used to compute inverses or
-# matrix powers. Must be nonnegative.
-EIGENVALUE_CLIPPING_THRESHOLD = 0.0
-
-# Used to subsample the flattened extracted image patches. The number of
-# outer products per row of the covariance matrix should not exceed this
-# value. This parameter is used only if `_SUB_SAMPLE_OUTER_PRODUCTS` is True.
-_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1
-
-# Used to subsample the inputs passed to the extract image patches. The batch
-# size of number of inputs to extract image patches is multiplied by this
-# factor. This parameter is used only if `_SUB_SAMPLE_INPUTS` is True.
-_INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.5
-
-# If True, then subsamples the tensor passed to compute the covariance matrix.
-_SUB_SAMPLE_OUTER_PRODUCTS = False
-
-# If True, then subsamples the tensor passed to compute the covariance matrix.
-_SUB_SAMPLE_INPUTS = False
-
-# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data
-# passed to the factors from the blocks will be concatenated across towers
-# (lazily via PartitionedTensor objects). Otherwise a tuple of tensors over
-# towers will be passed in, and the factors will iterate over this and do the
-# cov computations separately for each one, averaging the results together.
-TOWER_STRATEGY = "concat"
-
-
-def set_global_constants(init_covariances_at_zero=None,
- zero_debias=None,
- init_inverses_at_zero=None,
- eigenvalue_decomposition_threshold=None,
- eigenvalue_clipping_threshold=None,
- max_num_outer_products_per_cov_row=None,
- sub_sample_outer_products=None,
- inputs_to_extract_patches_factor=None,
- sub_sample_inputs=None,
- tower_strategy=None):
- """Sets various global constants used by the classes in this module."""
- global INIT_COVARIANCES_AT_ZERO
- global ZERO_DEBIAS
- global INIT_INVERSES_AT_ZERO
- global EIGENVALUE_DECOMPOSITION_THRESHOLD
- global EIGENVALUE_CLIPPING_THRESHOLD
- global _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW
- global _SUB_SAMPLE_OUTER_PRODUCTS
- global _INPUTS_TO_EXTRACT_PATCHES_FACTOR
- global _SUB_SAMPLE_INPUTS
- global TOWER_STRATEGY
-
- if init_covariances_at_zero is not None:
- INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero
- if zero_debias is not None:
- ZERO_DEBIAS = zero_debias
- if init_inverses_at_zero is not None:
- INIT_INVERSES_AT_ZERO = init_inverses_at_zero
- if eigenvalue_decomposition_threshold is not None:
- EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold
- if eigenvalue_clipping_threshold is not None:
- EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold
- if max_num_outer_products_per_cov_row is not None:
- _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = max_num_outer_products_per_cov_row
- if sub_sample_outer_products is not None:
- _SUB_SAMPLE_OUTER_PRODUCTS = sub_sample_outer_products
- if inputs_to_extract_patches_factor is not None:
- _INPUTS_TO_EXTRACT_PATCHES_FACTOR = inputs_to_extract_patches_factor
- if sub_sample_inputs is not None:
- _SUB_SAMPLE_INPUTS = sub_sample_inputs
- if tower_strategy is not None:
- TOWER_STRATEGY = tower_strategy
-
-
-def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
- if INIT_INVERSES_AT_ZERO:
- return array_ops.zeros(shape, dtype=dtype)
- return linalg_ops.eye(num_rows=shape[0], dtype=dtype)
-
-
-def covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
- if INIT_COVARIANCES_AT_ZERO:
- return array_ops.zeros(shape, dtype=dtype)
- return linalg_ops.eye(num_rows=shape[0], dtype=dtype)
-
-
-def diagonal_covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
- if INIT_COVARIANCES_AT_ZERO:
- return array_ops.zeros(shape, dtype=dtype)
- return array_ops.ones(shape, dtype=dtype)
-
-
-@contextlib.contextmanager
-def place_on_device(device):
- if device is not None and len(device):
- with tf_ops.device(device):
- yield
- else:
- yield
-
-
-def compute_cov(tensor, tensor_right=None, normalizer=None):
- """Compute the empirical second moment of the rows of a 2D Tensor.
-
- This function is meant to be applied to random matrices for which the true row
- mean is zero, so that the true second moment equals the true covariance.
-
- Args:
- tensor: A 2D Tensor.
- tensor_right: An optional 2D Tensor. If provided, this function computes
- the matrix product tensor^T * tensor_right instead of tensor^T * tensor.
- normalizer: optional scalar for the estimator (by default, the normalizer is
- the number of rows of tensor).
-
- Returns:
- A square 2D Tensor with as many rows/cols as the number of input columns.
- """
- if normalizer is None:
- normalizer = array_ops.shape(tensor)[0]
- if tensor_right is None:
- cov = (
- math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast(
- normalizer, tensor.dtype))
- return (cov + array_ops.transpose(cov)) / math_ops.cast(2.0, cov.dtype)
- else:
- return (math_ops.matmul(tensor, tensor_right, transpose_a=True) /
- math_ops.cast(normalizer, tensor.dtype))
-
-
-def append_homog(tensor):
- """Appends a homogeneous coordinate to the last dimension of a Tensor.
-
- Args:
- tensor: A Tensor.
-
- Returns:
- A Tensor identical to the input but one larger in the last dimension. The
- new entries are filled with ones.
- """
- rank = len(tensor.shape.as_list())
- shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0)
- ones = array_ops.ones(shape, dtype=tensor.dtype)
- return array_ops.concat([tensor, ones], axis=rank - 1)
-
-
-def scope_string_from_params(params):
- """Builds a variable scope string name from the given parameters.
-
- Supported parameters are:
- * tensors
- * booleans
- * ints
- * strings
- * depth-1 tuples/lists of ints
- * any depth tuples/lists of tensors
- Other parameter types will throw an error.
-
- Args:
- params: A parameter or list of parameters.
-
- Returns:
- A string to use for the variable scope.
-
- Raises:
- ValueError: if params includes an unsupported type.
- """
- params = params if isinstance(params, (tuple, list)) else (params,)
-
- name_parts = []
- for param in params:
- if param is None:
- name_parts.append("None")
- elif isinstance(param, (tuple, list)):
- if all([isinstance(p, int) for p in param]):
- name_parts.append("-".join([str(p) for p in param]))
- else:
- name_parts.append(scope_string_from_name(param))
- elif isinstance(param, (str, int, bool)):
- name_parts.append(str(param))
- elif isinstance(param, (tf_ops.Tensor, variables.Variable)):
- name_parts.append(scope_string_from_name(param))
- elif isinstance(param, utils.PartitionedTensor):
- name_parts.append(scope_string_from_name(param.tensors))
- else:
- raise ValueError("Encountered an unsupported param type {}".format(
- type(param)))
- return "_".join(name_parts)
-
-
-def scope_string_from_name(tensor):
- if isinstance(tensor, (tuple, list)):
- return "__".join([scope_string_from_name(t) for t in tensor])
- # "gradients/add_4_grad/Reshape:0" -> "gradients_add_4_grad_Reshape"
- return tensor.name.split(":")[0].replace("/", "_")
-
-
-def scalar_or_tensor_to_string(val):
- return repr(val) if np.isscalar(val) else scope_string_from_name(val)
-
-
-def list_to_string(lst):
- return "_".join(val if isinstance(val, six.string_types)
- else scalar_or_tensor_to_string(val) for val in lst)
-
-
-def graph_func_to_id(func):
- """Returns a hashable object that represents func's computation."""
- # TODO(b/74201126): replace with Topohash of func's output
- return func.func_id
-
-
-def graph_func_to_string(func):
- # TODO(b/74201126): replace with Topohash of func's output
- return list_to_string(func.func_id)
-
-
-def _subsample_for_cov_computation(array, name=None):
- """Subsamples the first dimension of the array.
-
- `array`(A) is a tensor of shape `[batch_size, dim_2]`. Then the covariance
- matrix(A^TA) is of shape `dim_2 ** 2`. Subsample only if the number of outer
- products per row of the covariance matrix is greater than
- `_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW`.
-
- Args:
- array: Tensor, of shape `[batch_size, dim_2]`.
- name: `string`, Default(None)
-
- Returns:
- A tensor of shape `[max_samples, dim_2]`.
-
- Raises:
- ValueError: If array's is not matrix-shaped.
- ValueError: If array's batch_size cannot be inferred.
-
- """
- with tf_ops.name_scope(name, "subsample", [array]):
- array = tf_ops.convert_to_tensor(array)
- if len(array.shape) != 2:
- raise ValueError("Input param array must be a matrix.")
-
- batch_size = array.shape.as_list()[0]
- if batch_size is None:
- raise ValueError("Unable to get batch_size from input param array.")
-
- num_cov_rows = array.shape.as_list()[-1]
- max_batch_size = int(_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW * num_cov_rows)
- if batch_size <= max_batch_size:
- return array
-
- return _random_tensor_gather(array, max_batch_size)
-
-
-def _random_tensor_gather(array, max_size):
- """Generates a random set of indices and gathers the value at the indices.
-
- Args:
- array: Tensor, of shape `[batch_size, dim_2]`.
- max_size: int, Number of indices to sample.
-
- Returns:
- A tensor of shape `[max_size, ...]`.
- """
- batch_size = array.shape.as_list()[0]
- indices = random_ops.random_shuffle(math_ops.range(0, batch_size))[:max_size]
- return array_ops.gather(array, indices)
-
-
-@six.add_metaclass(abc.ABCMeta)
-class FisherFactor(object):
- """Base class for objects modeling factors of approximate Fisher blocks.
-
- A FisherFactor represents part of an approximate Fisher Information matrix.
- For example, one approximation to the Fisher uses the Kronecker product of two
- FisherFactors A and B, F = kron(A, B). FisherFactors are composed with
- FisherBlocks to construct a block-diagonal approximation to the full Fisher.
-
- FisherFactors are backed by a single, non-trainable variable that is updated
- by running FisherFactor.make_covariance_update_op(). The shape and type of
- this variable is implementation specific.
-
- Note that for blocks that aren't based on approximations, a 'factor' can
- be the entire block itself, as is the case for the diagonal and full
- representations.
- """
-
- def __init__(self):
- self._cov = None
-
- @abc.abstractproperty
- def _var_scope(self):
- """Variable scope for this FisherFactor instance.
-
- Returns:
- string that unique identifies this FisherFactor instance.
- """
- pass
-
- @property
- def name(self):
- return self._var_scope
-
- @abc.abstractproperty
- def _cov_shape(self):
- """The shape of the variable backing this FisherFactor."""
- pass
-
- @abc.abstractproperty
- def _num_sources(self):
- """The number of things to sum over when updating covariance variable.
-
- The default make_covariance_update_op function will call _compute_new_cov
- with indices ranging from 0 to _num_sources-1. The typical situation is
- where the factor wants to sum the statistics it computes over multiple
- backpropped "gradients" (typically passed in via "tensors" or
- "outputs_grads" arguments).
- """
- pass
-
- @abc.abstractproperty
- def _num_towers(self):
- pass
-
- @abc.abstractproperty
- def _dtype(self):
- """dtype for variable backing this factor."""
- pass
-
- @property
- def _cov_initializer(self):
- """Function for initializing covariance variable."""
- return covariance_initializer
-
- def instantiate_cov_variables(self):
- """Makes the internal cov variable(s)."""
- assert self._cov is None
- with variable_scope.variable_scope(self._var_scope):
- self._cov = variable_scope.get_variable(
- "cov",
- initializer=self._cov_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
-
- @abc.abstractmethod
- def _compute_new_cov(self, source, tower):
- """Computes minibatch-estimated covariance for a single source.
-
- Args:
- source: int in [0, self._num_sources). Which source to use when computing
- the cov update.
- tower: int in [0, self._num_towers). Which tower to use when computing
- the cov update.
-
- Returns:
- Tensor of same shape as self.get_cov().
- """
- pass
-
- def make_covariance_update_op(self, ema_decay):
- """Constructs and returns the covariance update Op.
-
- Args:
- ema_decay: The exponential moving average decay (float or Tensor).
- Returns:
- An Op for updating the covariance Variable referenced by _cov.
- """
- new_cov_contribs = []
- for source in range(self._num_sources):
- for tower in range(self._num_towers):
- device = (self._get_data_device(tower)
- if TOWER_STRATEGY == "separate" else None)
- with place_on_device(device):
- new_cov_contribs.append(self._compute_new_cov(source, tower))
-
- new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers)
-
- # Compute average of 'new_cov' across all TPU cores. On a TPU, each
- # instance of 'new_cov' will be based on a different minibatch. This ensures
- # that by the end of assign_moving_average(), all TPU cores see the same
- # value for self._cov.
- #
- # Other implementations of make_covariance_update_op() that accumulate
- # statistics in other variables should mimic this behavior.
- if utils.on_tpu():
- new_cov = utils.cross_replica_mean(new_cov)
-
- return moving_averages.assign_moving_average(
- self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
-
- @abc.abstractmethod
- def _get_data_device(self, tower):
- pass
-
- @abc.abstractmethod
- def instantiate_inv_variables(self):
- """Makes the internal "inverse" variable(s)."""
- pass
-
- @abc.abstractmethod
- def make_inverse_update_ops(self):
- """Create and return update ops corresponding to registered computations."""
- pass
-
- def get_cov(self):
- return self._cov
-
- @abc.abstractmethod
- def get_cov_as_linear_operator(self):
- pass
-
- @abc.abstractmethod
- def register_matpower(self, exp, damping_func):
- pass
-
- @abc.abstractmethod
- def register_cholesky(self, damping_func):
- pass
-
- @abc.abstractmethod
- def register_cholesky_inverse(self, damping_func):
- pass
-
- @abc.abstractmethod
- def get_matpower(self, exp, damping_func):
- pass
-
- @abc.abstractmethod
- def get_cholesky(self, damping_func):
- pass
-
- @abc.abstractmethod
- def get_cholesky_inverse(self, damping_func):
- pass
-
-
-class DenseSquareMatrixFactor(FisherFactor):
- """Base class for FisherFactors that are stored as dense square matrices.
-
- This class explicitly calculates and stores inverses of their `cov` matrices,
- which must be square dense matrices.
-
- Subclasses must implement the _compute_new_cov method, and the _var_scope and
- _cov_shape properties.
- """
-
- # TODO(b/69108481): This class (and its subclasses) should be refactored to
- # serve the matrix quantities it computes as both (potentially stale)
- # variables, updated by the inverse update ops, and fresh values stored in
- # tensors that recomputed once every session.run() call. Currently matpower
- # and damp_inverse have the former behavior, while eigendecomposition has
- # the latter.
-
- def __init__(self):
- self._matpower_by_exp_and_damping = {} # { (float, hashable): variable }
- self._matpower_registrations = set() # { (float, hashable) }
- self._eigendecomp = None
- self._damping_funcs_by_id = {} # {hashable: lambda}
-
- self._cholesky_registrations = set() # { hashable }
- self._cholesky_inverse_registrations = set() # { hashable }
-
- self._cholesky_by_damping = {} # { hashable: variable }
- self._cholesky_inverse_by_damping = {} # { hashable: variable }
-
- super(DenseSquareMatrixFactor, self).__init__()
-
- def get_cov_as_linear_operator(self):
- assert self.get_cov().shape.ndims == 2
- return lo.LinearOperatorFullMatrix(self.get_cov(),
- is_self_adjoint=True,
- is_square=True)
-
- def _register_damping(self, damping_func):
- damping_id = graph_func_to_id(damping_func)
- if damping_id not in self._damping_funcs_by_id:
- self._damping_funcs_by_id[damping_id] = damping_func
- return damping_id
-
- def register_inverse(self, damping_func):
- # Just for backwards compatibility of some old code and tests
- self.register_matpower(-1, damping_func)
-
- def register_matpower(self, exp, damping_func):
- """Registers a matrix power to be maintained and served on demand.
-
- This creates a variable and signals make_inverse_update_ops to make the
- corresponding update op. The variable can be read via the method
- get_matpower.
-
- Args:
- exp: float. The exponent to use in the matrix power.
- damping_func: A function that computes a 0-D Tensor or a float which will
- be the damping value used. i.e. damping = damping_func().
- """
- if exp == 1.0:
- return
-
- damping_id = self._register_damping(damping_func)
-
- if (exp, damping_id) not in self._matpower_registrations:
- self._matpower_registrations.add((exp, damping_id))
-
- def register_cholesky(self, damping_func):
- """Registers a Cholesky factor to be maintained and served on demand.
-
- This creates a variable and signals make_inverse_update_ops to make the
- corresponding update op. The variable can be read via the method
- get_cholesky.
-
- Args:
- damping_func: A function that computes a 0-D Tensor or a float which will
- be the damping value used. i.e. damping = damping_func().
- """
- damping_id = self._register_damping(damping_func)
-
- if damping_id not in self._cholesky_registrations:
- self._cholesky_registrations.add(damping_id)
-
- def register_cholesky_inverse(self, damping_func):
- """Registers an inverse Cholesky factor to be maintained/served on demand.
-
- This creates a variable and signals make_inverse_update_ops to make the
- corresponding update op. The variable can be read via the method
- get_cholesky_inverse.
-
- Args:
- damping_func: A function that computes a 0-D Tensor or a float which will
- be the damping value used. i.e. damping = damping_func().
- """
- damping_id = self._register_damping(damping_func)
-
- if damping_id not in self._cholesky_inverse_registrations:
- self._cholesky_inverse_registrations.add(damping_id)
-
- def instantiate_inv_variables(self):
- """Makes the internal "inverse" variable(s)."""
-
- for (exp, damping_id) in self._matpower_registrations:
- exp_string = scalar_or_tensor_to_string(exp)
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
- with variable_scope.variable_scope(self._var_scope):
- matpower = variable_scope.get_variable(
- "matpower_exp{}_damp{}".format(exp_string, damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- assert (exp, damping_id) not in self._matpower_by_exp_and_damping
- self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower
-
- for damping_id in self._cholesky_registrations:
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
- with variable_scope.variable_scope(self._var_scope):
- chol = variable_scope.get_variable(
- "cholesky_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- assert damping_id not in self._cholesky_by_damping
- self._cholesky_by_damping[damping_id] = chol
-
- for damping_id in self._cholesky_inverse_registrations:
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
- with variable_scope.variable_scope(self._var_scope):
- cholinv = variable_scope.get_variable(
- "cholesky_inverse_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- assert damping_id not in self._cholesky_inverse_by_damping
- self._cholesky_inverse_by_damping[damping_id] = cholinv
-
- def make_inverse_update_ops(self):
- """Create and return update ops corresponding to registered computations."""
- ops = []
-
- num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping
- if exp == -1)
-
- num_other_matpower = len(self._matpower_by_exp_and_damping) - num_inverses
-
- other_matrix_power_registered = num_other_matpower >= 1
-
- use_eig = (
- self._eigendecomp or other_matrix_power_registered or
- num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD)
-
- # We precompute these so we don't need to evaluate them multiple times (for
- # each matrix power that uses them)
- damping_value_by_id = {damping_id: math_ops.cast(
- self._damping_funcs_by_id[damping_id](), self._dtype)
- for damping_id in self._damping_funcs_by_id}
-
- if use_eig:
- eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence
-
- for (exp, damping_id), matpower in (
- self._matpower_by_exp_and_damping.items()):
- damping = damping_value_by_id[damping_id]
- ops.append(
- matpower.assign(
- math_ops.matmul(eigenvectors *
- (eigenvalues + damping)**exp,
- array_ops.transpose(eigenvectors))))
- # These ops share computation and should be run on a single device.
- ops = [control_flow_ops.group(*ops)]
- else:
- for (exp, damping_id), matpower in (
- self._matpower_by_exp_and_damping.items()):
- assert exp == -1
- damping = damping_value_by_id[damping_id]
- ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping)))
-
- # TODO(b/77902055): If inverses are being computed with Cholesky's
- # we can share the work. Instead this code currently just computes the
- # Cholesky a second time. It does at least share work between requests for
- # Cholesky's and Cholesky inverses with the same damping id.
- for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items():
- cholesky_ops = []
-
- damping = damping_value_by_id[damping_id]
- cholesky_value = utils.cholesky(self.get_cov(), damping)
-
- if damping_id in self._cholesky_by_damping:
- cholesky = self._cholesky_by_damping[damping_id]
- cholesky_ops.append(cholesky.assign(cholesky_value))
-
- identity = linalg_ops.eye(cholesky_value.shape.as_list()[0],
- dtype=cholesky_value.dtype)
- cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value,
- identity)
- cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value))
-
- ops.append(control_flow_ops.group(*cholesky_ops))
-
- for damping_id, cholesky in self._cholesky_by_damping.items():
- if damping_id not in self._cholesky_inverse_by_damping:
- damping = damping_value_by_id[damping_id]
- cholesky_value = utils.cholesky(self.get_cov(), damping)
- ops.append(cholesky.assign(cholesky_value))
-
- self._eigendecomp = False
- return ops
-
- def get_inverse(self, damping_func):
- # Just for backwards compatibility of some old code and tests
- return self.get_matpower(-1, damping_func)
-
- def get_matpower(self, exp, damping_func):
- # Note that this function returns a variable which gets updated by the
- # inverse ops. It may be stale / inconsistent with the latest value of
- # get_cov().
- if exp != 1:
- damping_id = graph_func_to_id(damping_func)
- matpower = self._matpower_by_exp_and_damping[(exp, damping_id)]
- else:
- matpower = self.get_cov()
- identity = linalg_ops.eye(matpower.shape.as_list()[0],
- dtype=matpower.dtype)
- matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity
-
- assert matpower.shape.ndims == 2
- return lo.LinearOperatorFullMatrix(matpower,
- is_non_singular=True,
- is_self_adjoint=True,
- is_positive_definite=True,
- is_square=True)
-
- def get_cholesky(self, damping_func):
- # Note that this function returns a variable which gets updated by the
- # inverse ops. It may be stale / inconsistent with the latest value of
- # get_cov().
- damping_id = graph_func_to_id(damping_func)
- cholesky = self._cholesky_by_damping[damping_id]
- assert cholesky.shape.ndims == 2
- return lo.LinearOperatorFullMatrix(cholesky,
- is_non_singular=True,
- is_square=True)
-
- def get_cholesky_inverse(self, damping_func):
- # Note that this function returns a variable which gets updated by the
- # inverse ops. It may be stale / inconsistent with the latest value of
- # get_cov().
- damping_id = graph_func_to_id(damping_func)
- cholesky_inv = self._cholesky_inverse_by_damping[damping_id]
- assert cholesky_inv.shape.ndims == 2
- return lo.LinearOperatorFullMatrix(cholesky_inv,
- is_non_singular=True,
- is_square=True)
-
- def get_eigendecomp(self):
- """Creates or retrieves eigendecomposition of self._cov."""
- # Unlike get_matpower this doesn't retrieve a stored variable, but instead
- # always computes a fresh version from the current value of get_cov().
- if not self._eigendecomp:
- eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov())
-
- # The matrix self._cov is positive semidefinite by construction, but the
- # numerical eigenvalues could be negative due to numerical errors, so here
- # we clip them to be at least FLAGS.eigenvalue_clipping_threshold
- clipped_eigenvalues = math_ops.maximum(eigenvalues,
- EIGENVALUE_CLIPPING_THRESHOLD)
- self._eigendecomp = (clipped_eigenvalues, eigenvectors)
-
- return self._eigendecomp
-
-
-class FullFactor(DenseSquareMatrixFactor):
- """FisherFactor for a full matrix representation of the Fisher of a parameter.
-
- Note that this uses the naive "square the sum estimator", and so is applicable
- to any type of parameter in principle, but has very high variance.
- """
-
- def __init__(self,
- params_grads,
- batch_size):
- self._batch_size = batch_size
- self._params_grads = tuple(utils.ensure_sequence(params_grad)
- for params_grad in params_grads)
- super(FullFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_full_" + scope_string_from_params(
- [self._params_grads, self._batch_size])
-
- @property
- def _cov_shape(self):
- size = sum(param_grad.shape.num_elements()
- for param_grad in self._params_grads[0])
- return (size, size)
-
- @property
- def _num_sources(self):
- return len(self._params_grads)
-
- @property
- def _num_towers(self):
- return 1
-
- @property
- def _dtype(self):
- return self._params_grads[0][0].dtype
-
- def _compute_new_cov(self, source, tower):
- assert tower == 0
-
- # This will be a very basic rank 1 estimate
- params_grads_flat = utils.tensors_to_column(self._params_grads[source])
- return ((params_grads_flat * array_ops.transpose(
- params_grads_flat)) / math_ops.cast(self._batch_size,
- params_grads_flat.dtype))
-
- def _get_data_device(self, tower):
- return None
-
-
-class DiagonalFactor(FisherFactor):
- """A base class for FisherFactors that use diagonal approximations.
-
- A DiagonalFactor's covariance variable can be of any shape, but must contain
- exactly one entry per parameter.
- """
-
- def __init__(self):
- super(DiagonalFactor, self).__init__()
-
- def get_cov_as_linear_operator(self):
- assert self._matrix_diagonal.shape.ndims == 1
- return lo.LinearOperatorDiag(self._matrix_diagonal,
- is_self_adjoint=True,
- is_square=True)
-
- @property
- def _cov_initializer(self):
- return diagonal_covariance_initializer
-
- @property
- def _matrix_diagonal(self):
- return array_ops.reshape(self.get_cov(), [-1])
-
- def make_inverse_update_ops(self):
- return []
-
- def instantiate_inv_variables(self):
- pass
-
- def register_matpower(self, exp, damping_func):
- pass
-
- def register_cholesky(self, damping_func):
- pass
-
- def register_cholesky_inverse(self, damping_func):
- pass
-
- def get_matpower(self, exp, damping_func):
- matpower_diagonal = (self._matrix_diagonal
- + math_ops.cast(damping_func(), self._dtype))**exp
- return lo.LinearOperatorDiag(matpower_diagonal,
- is_non_singular=True,
- is_self_adjoint=True,
- is_positive_definite=True,
- is_square=True)
-
- def get_cholesky(self, damping_func):
- return self.get_matpower(0.5, damping_func)
-
- def get_cholesky_inverse(self, damping_func):
- return self.get_matpower(-0.5, damping_func)
-
-
-class NaiveDiagonalFactor(DiagonalFactor):
- """FisherFactor for a diagonal approximation of any type of param's Fisher.
-
- Note that this uses the naive "square the sum estimator", and so is applicable
- to any type of parameter in principle, but has very high variance.
- """
-
- def __init__(self,
- params_grads,
- batch_size):
- """Initializes NaiveDiagonalFactor instance.
-
- Args:
- params_grads: Sequence of Tensors, each with same shape as parameters this
- FisherFactor corresponds to. For example, the gradient of the loss with
- respect to parameters.
- batch_size: int or 0-D Tensor. Size
- """
- self._params_grads = tuple(utils.ensure_sequence(params_grad)
- for params_grad in params_grads)
- self._batch_size = batch_size
- super(NaiveDiagonalFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_naivediag_" + scope_string_from_params(
- [self._params_grads, self._batch_size])
-
- @property
- def _cov_shape(self):
- size = sum(param_grad.shape.num_elements()
- for param_grad in self._params_grads[0])
- return [size, 1]
-
- @property
- def _num_sources(self):
- return len(self._params_grads)
-
- @property
- def _num_towers(self):
- return 1
-
- @property
- def _dtype(self):
- return self._params_grads[0][0].dtype
-
- def _compute_new_cov(self, source, tower):
- assert tower == 0
-
- params_grads_flat = utils.tensors_to_column(self._params_grads[source])
- return (math_ops.square(params_grads_flat) / math_ops.cast(
- self._batch_size, params_grads_flat.dtype))
-
- def _get_data_device(self, tower):
- return None
-
-
-class EmbeddingInputKroneckerFactor(DiagonalFactor):
- r"""FisherFactor for input to an embedding layer.
-
- Given input_ids = [batch_size, input_size] representing indices into an
- [vocab_size, embedding_size] embedding matrix, approximate input covariance by
- a diagonal matrix,
-
- Cov(input_ids, input_ids) =
- (1/batch_size) sum_{i} diag(n_hot(input[i]) ** 2).
-
- where n_hot() constructs an n-hot binary vector and diag() constructs a
- diagonal matrix of size [vocab_size, vocab_size].
- """
-
- def __init__(self, input_ids, vocab_size, dtype=None):
- """Instantiate EmbeddingInputKroneckerFactor.
-
- Args:
- input_ids: List of Tensors of shape [batch_size, input_size] and dtype
- int32. Indices into embedding matrix. List index is tower.
- vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'.
- dtype: dtype for covariance statistics. Must be a floating point type.
- Defaults to float32.
- """
- self._input_ids = input_ids
- self._vocab_size = vocab_size
- self._cov_dtype = dtype or dtypes.float32
-
- super(EmbeddingInputKroneckerFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_diag_embedding_" + scope_string_from_params(self._input_ids)
-
- @property
- def _cov_shape(self):
- return [self._vocab_size]
-
- @property
- def _num_sources(self):
- return 1
-
- @property
- def _num_towers(self):
- return len(self._input_ids)
-
- @property
- def _dtype(self):
- return self._cov_dtype
-
- def _compute_new_cov(self, source, tower):
- assert source == 0
-
- input_ids = self._input_ids[tower]
-
- if len(input_ids.shape) > 2:
- raise ValueError(
- "Input to embeddings must have rank <= 2. Found rank %d." % len(
- input_ids.shape))
-
- batch_size = array_ops.shape(input_ids)[0]
-
- # Transform indices into one-hot vectors.
- #
- # TODO(b/72714822): There must be a faster way to construct the diagonal
- # covariance matrix! This operation is O(batch_size * vocab_size), where
- # it should be O(batch_size * input_size).
- flat_input_ids = array_ops.reshape(input_ids, [-1])
- one_hots = array_ops.one_hot(flat_input_ids,
- self._vocab_size) # [?, vocab_size]
-
- # Take average across examples. Note that, because all entries have
- # magnitude zero or one, there's no need to square the entries.
- #
- # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation
- # within an example such as average.
- #
- # TODO(b/72714822): Support for partitioned embeddings.
- new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size]
- new_cov /= math_ops.cast(batch_size, new_cov.dtype)
-
- return new_cov
-
- def _get_data_device(self, tower):
- return self._input_ids[tower].device
-
-
-class FullyConnectedDiagonalFactor(DiagonalFactor):
- r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher.
-
- Given in = [batch_size, input_size] and out_grad = [batch_size, output_size],
- approximates the covariance as,
-
- Cov(in, out) = (1/batch_size) sum_{i} outer(in[i], out_grad[i]) ** 2.0
-
- where the square is taken element-wise.
- """
-
- def __init__(self,
- inputs,
- outputs_grads,
- has_bias=False):
- """Instantiate FullyConnectedDiagonalFactor.
-
- Args:
- inputs: List of Tensors of shape [batch_size, input_size]. Inputs to this
- layer. List index is towers.
- outputs_grads: List of Tensors, each of shape [batch_size, output_size],
- which are the gradients of the loss with respect to the layer's
- outputs. First index is source, second is tower.
-
- has_bias: bool. If True, append '1' to each input.
- """
- self._inputs = inputs
- self._has_bias = has_bias
- self._outputs_grads = outputs_grads
- self._squared_inputs = None
-
- super(FullyConnectedDiagonalFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_diagfc_" + scope_string_from_params(
- tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads)))
-
- @property
- def _cov_shape(self):
- input_size = self._inputs[0].shape[1] + self._has_bias
- output_size = self._outputs_grads[0][0].shape[1]
- return [input_size, output_size]
-
- @property
- def _num_sources(self):
- return len(self._outputs_grads)
-
- @property
- def _num_towers(self):
- return len(self._inputs)
-
- @property
- def _dtype(self):
- return self._outputs_grads[0][0].dtype
-
- def make_covariance_update_op(self, ema_decay):
-
- self._squared_inputs = []
- for tower in range(self._num_towers):
- inputs = self._inputs[tower]
-
- with place_on_device(self._get_data_device(tower)):
- if self._has_bias:
- inputs = append_homog(inputs)
- self._squared_inputs.append(math_ops.square(inputs))
-
- return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op(
- ema_decay)
-
- def _compute_new_cov(self, source, tower):
- batch_size = array_ops.shape(self._squared_inputs[tower])[0]
- outputs_grad = self._outputs_grads[source][tower]
-
- # The well-known special formula that uses the fact that the entry-wise
- # square of an outer product is the outer-product of the entry-wise squares.
- # The gradient is the outer product of the input and the output gradients,
- # so we just square both and then take their outer-product.
- new_cov = math_ops.matmul(
- self._squared_inputs[tower],
- math_ops.square(outputs_grad),
- transpose_a=True)
- new_cov /= math_ops.cast(batch_size, new_cov.dtype)
- return new_cov
-
- def _get_data_device(self, tower):
- return self._inputs[tower].device
-
-
-class ConvDiagonalFactor(DiagonalFactor):
- """FisherFactor for a diagonal approx of a convolutional layer's Fisher."""
-
- def __init__(self,
- inputs,
- outputs_grads,
- filter_shape,
- strides,
- padding,
- data_format=None,
- dilations=None,
- has_bias=False):
- """Creates a ConvDiagonalFactor object.
-
- Args:
- inputs: List of Tensors of shape [batch_size, height, width, in_channels].
- Input activations to this layer. List index is towers.
- outputs_grads: List of Tensors, each of shape [batch_size,
- height, width, out_channels], which are the gradients of the loss
- with respect to the layer's outputs. First index is source, second
- index is tower.
- filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels,
- out_channels). Represents shape of kernel used in this layer.
- strides: The stride size in this layer (1-D Tensor of length 4).
- padding: The padding in this layer (1-D of Tensor length 4).
- data_format: None or str. Format of conv2d inputs.
- dilations: None or tuple of 4 ints.
- has_bias: Python bool. If True, the layer is assumed to have a bias
- parameter in addition to its filter parameter.
-
- Raises:
- ValueError: If inputs, output_grads, and filter_shape do not agree on
- in_channels or out_channels.
- ValueError: If strides, dilations are not length-4 lists of ints.
- ValueError: If data_format does not put channel last.
- """
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("Channel must be last.")
- if any(input_.shape.ndims != 4 for input_ in inputs):
- raise ValueError("inputs must be a list of 4-D Tensors.")
- if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs):
- raise ValueError("inputs and filter_shape must agree on in_channels.")
- for i, outputs_grad in enumerate(outputs_grads):
- if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad):
- raise ValueError("outputs[%d] must be 4-D Tensor." % i)
- if any(output_grad.shape.as_list()[-1] != filter_shape[-1]
- for output_grad in outputs_grad):
- raise ValueError(
- "outputs[%d] and filter_shape must agree on out_channels." % i)
- if len(strides) != 4:
- raise ValueError("strides must be length-4 list of ints.")
- if dilations is not None and len(dilations) != 4:
- raise ValueError("dilations must be length-4 list of ints.")
-
- self._inputs = inputs
- self._outputs_grads = outputs_grads
- self._filter_shape = filter_shape
- self._strides = strides
- self._padding = padding
- self._data_format = data_format
- self._dilations = dilations
- self._has_bias = has_bias
- self._patches = None
-
- super(ConvDiagonalFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_convdiag_" + scope_string_from_params(
- tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads)))
-
- @property
- def _cov_shape(self):
- filter_height, filter_width, in_channels, out_channels = self._filter_shape
- return [
- filter_height * filter_width * in_channels + self._has_bias,
- out_channels
- ]
-
- @property
- def _num_sources(self):
- return len(self._outputs_grads)
-
- @property
- def _num_towers(self):
- return len(self._inputs)
-
- @property
- def _dtype(self):
- return self._inputs[0].dtype
-
- def make_covariance_update_op(self, ema_decay):
- filter_height, filter_width, _, _ = self._filter_shape
-
- # TODO(b/64144716): there is potential here for a big savings in terms
- # of memory use.
- if self._dilations is None:
- rates = (1, 1, 1, 1)
- else:
- rates = tuple(self._dilations)
-
- self._patches = []
- for tower in range(self._num_towers):
- with place_on_device(self._get_data_device(tower)):
- patches = array_ops.extract_image_patches(
- self._inputs[tower],
- ksizes=[1, filter_height, filter_width, 1],
- strides=self._strides,
- rates=rates,
- padding=self._padding)
-
- if self._has_bias:
- patches = append_homog(patches)
-
- self._patches.append(patches)
-
- return super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay)
-
- def _compute_new_cov(self, source, tower):
- patches = self._patches[tower]
- batch_size = array_ops.shape(patches)[0]
- outputs_grad = self._outputs_grads[source][tower]
-
- new_cov = self._convdiag_sum_of_squares(patches, outputs_grad)
- new_cov /= math_ops.cast(batch_size, new_cov.dtype)
-
- return new_cov
-
- def _convdiag_sum_of_squares(self, patches, outputs_grad):
- # This computes the sum of the squares of the per-training-case "gradients".
- # It does this simply by computing a giant tensor containing all of these,
- # doing an entry-wise square, and them summing along the batch dimension.
- case_wise_gradients = special_math_ops.einsum("bijk,bijl->bkl", patches,
- outputs_grad)
- return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0)
-
- def _get_data_device(self, tower):
- return self._inputs[tower].device
-
-
-class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor):
- """Kronecker factor for the input or output side of a fully-connected layer.
- """
-
- def __init__(self,
- tensors,
- has_bias=False):
- """Instantiate FullyConnectedKroneckerFactor.
-
- Args:
- tensors: List of list of Tensors, each of shape [batch_size, n]. The
- Tensors are typically either a layer's inputs or its output's gradients.
- The first list index is source, the second is tower.
- has_bias: bool. If True, append '1' to each row.
- """
- # The tensor argument is either a tensor of input activations or a tensor of
- # output pre-activation gradients.
- self._has_bias = has_bias
- self._tensors = tensors
- super(FullyConnectedKroneckerFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_fckron_" + scope_string_from_params(
- tuple(nest.flatten(self._tensors)) + (self._has_bias,))
-
- @property
- def _cov_shape(self):
- size = self._tensors[0][0].shape[1] + self._has_bias
- return [size, size]
-
- @property
- def _num_sources(self):
- return len(self._tensors)
-
- @property
- def _num_towers(self):
- return len(self._tensors[0])
-
- @property
- def _dtype(self):
- return self._tensors[0][0].dtype
-
- def _compute_new_cov(self, source, tower):
- tensor = self._tensors[source][tower]
- if self._has_bias:
- tensor = append_homog(tensor)
- return compute_cov(tensor)
-
- def _get_data_device(self, tower):
- return self._tensors[0][tower].device
-
-
-class ConvInputKroneckerFactor(DenseSquareMatrixFactor):
- r"""Kronecker factor for the input side of a convolutional layer.
-
- Estimates E[ a a^T ] where a is the inputs to a convolutional layer given
- example x. Expectation is taken over all examples and locations.
-
- Equivalent to Omega in https://arxiv.org/abs/1602.01407 for details. See
- Section 3.1 Estimating the factors.
- """
-
- def __init__(self,
- inputs,
- filter_shape,
- padding,
- strides=None,
- dilation_rate=None,
- data_format=None,
- extract_patches_fn=None,
- has_bias=False,
- sub_sample_inputs=None,
- sub_sample_patches=None):
- """Initializes ConvInputKroneckerFactor.
-
- Args:
- inputs: List of Tensors of shape [batch_size, ..spatial_input_size..,
- in_channels]. Inputs to layer. List index is tower.
- filter_shape: List of ints. Contains [..spatial_filter_size..,
- in_channels, out_channels]. Shape of convolution kernel.
- padding: str. Padding method for layer. "SAME" or "VALID".
- strides: List of ints or None. Contains [..spatial_filter_strides..] if
- 'extract_patches_fn' is compatible with tf.nn.convolution(), else
- [1, ..spatial_filter_strides, 1].
- dilation_rate: List of ints or None. Rate for dilation along each spatial
- dimension if 'extract_patches_fn' is compatible with
- tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
- data_format: str or None. Format of input data.
- extract_patches_fn: str or None. Name of function that extracts image
- patches. One of "extract_convolution_patches", "extract_image_patches",
- "extract_pointwise_conv2d_patches".
- has_bias: bool. If True, append 1 to in_channel.
- sub_sample_inputs: `bool`. If True, then subsample the inputs from which
- the image patches are extracted. (Default: None)
- sub_sample_patches: `bool`, If `True` then subsample the extracted
- patches.(Default: None)
- """
- self._inputs = inputs
- self._filter_shape = filter_shape
- self._strides = strides
- self._padding = padding
- self._dilation_rate = dilation_rate
- self._data_format = data_format
- self._extract_patches_fn = extract_patches_fn
- self._has_bias = has_bias
- if sub_sample_inputs is None:
- self._sub_sample_inputs = _SUB_SAMPLE_INPUTS
- else:
- self._sub_sample_inputs = sub_sample_inputs
-
- if sub_sample_patches is None:
- self._sub_sample_patches = _SUB_SAMPLE_OUTER_PRODUCTS
- else:
- self._sub_sample_patches = sub_sample_patches
- super(ConvInputKroneckerFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_convinkron_" + scope_string_from_params(
- tuple(self._inputs) +
- tuple((self._filter_shape, self._strides, self._padding,
- self._dilation_rate, self._data_format, self._has_bias)))
-
- @property
- def _cov_shape(self):
- spatial_filter_shape = self._filter_shape[0:-2]
- in_channels = self._filter_shape[-2]
- size = np.prod(spatial_filter_shape) * in_channels + self._has_bias
- return [size, size]
-
- @property
- def _num_sources(self):
- return 1
-
- @property
- def _num_towers(self):
- return len(self._inputs)
-
- @property
- def _dtype(self):
- return self._inputs[0].dtype
-
- def _compute_new_cov(self, source, tower):
- assert source == 0
-
- inputs = self._inputs[tower]
- if self._sub_sample_inputs:
- batch_size = inputs.shape.as_list()[0]
- max_size = int(batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR)
- inputs = _random_tensor_gather(inputs, max_size)
-
- # TODO(b/64144716): there is potential here for a big savings in terms of
- # memory use.
- if self._extract_patches_fn in [None, "extract_convolution_patches"]:
- patches = utils.extract_convolution_patches(
- inputs,
- self._filter_shape,
- padding=self._padding,
- strides=self._strides,
- dilation_rate=self._dilation_rate,
- data_format=self._data_format)
-
- elif self._extract_patches_fn == "extract_image_patches":
- assert inputs.shape.ndims == 4
- assert len(self._filter_shape) == 4
- assert len(self._strides) == 4, self._strides
- if self._dilation_rate is None:
- rates = [1, 1, 1, 1]
- else:
- rates = self._dilation_rate
- assert len(rates) == 4
- assert rates[0] == rates[-1] == 1
- patches = array_ops.extract_image_patches(
- inputs,
- ksizes=[1] + list(self._filter_shape[0:-2]) + [1],
- strides=self._strides,
- rates=rates,
- padding=self._padding)
-
- elif self._extract_patches_fn == "extract_pointwise_conv2d_patches":
- assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)]
- assert self._filter_shape[0] == self._filter_shape[1] == 1
- patches = utils.extract_pointwise_conv2d_patches(
- inputs, self._filter_shape, data_format=None)
-
- else:
- raise NotImplementedError(self._extract_patches_fn)
-
- flatten_size = np.prod(self._filter_shape[0:-1])
- # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde
- # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14),
- # where M = minibatch size, |T| = number of spatial locations,
- # |Delta| = number of spatial offsets, and J = number of input maps
- # for convolutional layer l.
- patches_flat = array_ops.reshape(patches, [-1, flatten_size])
-
- # We append a homogenous coordinate to patches_flat if the layer has
- # bias parameters. This gives us [[A_l]]_H from the paper.
- if self._sub_sample_patches:
- patches_flat = _subsample_for_cov_computation(patches_flat)
-
- if self._has_bias:
- patches_flat = append_homog(patches_flat)
- # We call compute_cov without passing in a normalizer. compute_cov uses
- # the first dimension of patches_flat i.e. M|T| as the normalizer by
- # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with
- # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from
- # the paper but has a different scale here for consistency with
- # ConvOutputKroneckerFactor.
- # (Tilde omitted over A for clarity.)
- return compute_cov(patches_flat)
-
- def _get_data_device(self, tower):
- return self._inputs[tower].device
-
-
-class ConvOutputKroneckerFactor(DenseSquareMatrixFactor):
- r"""Kronecker factor for the output side of a convolutional layer.
-
- Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer
- given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over
- all examples and locations.
-
- Equivalent to Gamma in https://arxiv.org/abs/1602.01407 for details. See
- Section 3.1 Estimating the factors.
- """
-
- def __init__(self, outputs_grads, data_format=None):
- """Initializes ConvOutputKroneckerFactor.
-
- Args:
- outputs_grads: List of list of Tensors. Each Tensor is of shape
- [batch_size, ..spatial_input_size.., out_channels]. First list index
- is source, the second is tower.
- data_format: None or str. Format of outputs_grads.
-
- Raises:
- ValueError: If channels are not final dimension.
- """
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("Channel must be last.")
- self._out_channels = outputs_grads[0][0].shape.as_list()[-1]
- self._outputs_grads = outputs_grads
- super(ConvOutputKroneckerFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_convoutkron_" + scope_string_from_params(
- nest.flatten(self._outputs_grads))
-
- @property
- def _cov_shape(self):
- size = self._out_channels
- return [size, size]
-
- @property
- def _num_sources(self):
- return len(self._outputs_grads)
-
- @property
- def _num_towers(self):
- return len(self._outputs_grads[0])
-
- @property
- def _dtype(self):
- return self._outputs_grads[0][0].dtype
-
- def _compute_new_cov(self, source, tower):
- outputs_grad = self._outputs_grads[source][tower]
-
- # reshaped_tensor below is the matrix DS_l defined in the KFC paper
- # (tilde omitted over S for clarity). It has shape M|T| x I, where
- # M = minibatch size, |T| = number of spatial locations, and
- # I = number of output maps for convolutional layer l.
- reshaped_tensor = array_ops.reshape(outputs_grad, [-1, self._out_channels])
- # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov,
- # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
- # as defined in the paper, with shape I x I.
- # (Tilde omitted over S for clarity.)
- return compute_cov(reshaped_tensor)
-
- def _get_data_device(self, tower):
- return self._outputs_grads[0][tower].device
-
-
-class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
- """Kronecker factor for a fully connected layer used multiple times."""
-
- def __init__(self,
- tensors,
- num_uses=None,
- has_bias=False):
- """Constructs a new `FullyConnectedMultiKF`.
-
- Args:
- tensors: List of list of Tensors of shape, each of shape
- [num_uses * batch_size, n], and is a reshape version of a Tensor of
- shape [num_uses, batch_size, n]. Each of these tensors is usually a
- layer's inputs or its output's gradients. The first list index is
- sources, the second is towers.
- num_uses: int. The number of time-steps / uses.
- has_bias: bool. If True, '1' is appended to each row.
- """
-
- self._num_uses = num_uses
-
- self._cov_dt1 = None
- self._make_cov_dt1 = False
- self._option1quants_by_damping = {}
- self._option2quants_by_damping = {}
- self._option1quants_registrations = set()
- self._option2quants_registrations = set()
-
- super(FullyConnectedMultiKF, self).__init__(tensors=tensors,
- has_bias=has_bias)
-
- @property
- def _num_timesteps(self):
- return self._num_uses
-
- @property
- def _var_scope(self):
- return "ff_fc_multi_" + scope_string_from_params(
- tuple(nest.flatten(self._tensors))
- + (self._num_timesteps, self._has_bias,))
-
- def make_covariance_update_op(self, ema_decay):
-
- op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay)
-
- if self._cov_dt1 is not None:
- new_cov_dt1_contribs = []
- for source in range(self._num_sources):
- for tower in range(self._num_towers):
- with place_on_device(self._get_data_device(tower)):
- new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source,
- tower))
-
- new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs)
- / float(self._num_towers))
-
- # See comments in FisherFactor.make_covariance_update_op() for details.
- if utils.on_tpu():
- new_cov_dt1 = utils.cross_replica_mean(new_cov_dt1)
-
- op2 = moving_averages.assign_moving_average(
- self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS)
-
- # TODO(b/69112164):
- # It's important that _cov and _cov_dt1 remain consistent with each
- # other while the inverse ops are happening. How can we ensure this?
- # We will need to add explicit synchronization for this to
- # work with asynchronous training.
- op = control_flow_ops.group(op, op2)
-
- return op
-
- def _compute_new_cov_dt1(self, source, tower): # pylint: disable=missing-docstring
- tensor = self._tensors[source][tower]
- if self._has_bias:
- # This appending is technically done twice (the other time is for
- # _compute_new_cov())
- tensor = append_homog(tensor)
-
- total_len = array_ops.shape(tensor)[0]
- batch_size = total_len // self._num_timesteps
-
- tensor_present = tensor[:-batch_size, :]
- tensor_future = tensor[batch_size:, :]
-
- # We specify a normalizer for this computation to ensure a PSD Fisher
- # block estimate. This is equivalent to padding with zeros, as was done
- # in Section B.2 of the appendix.
- return compute_cov(
- tensor_future, tensor_right=tensor_present, normalizer=total_len)
-
- def _get_data_device(self, tower):
- return self._tensors[0][tower].device
-
- @property
- def _vec_shape(self):
- size = self._tensors[0][0].shape[1] + self._has_bias
- return [size]
-
- def get_option1quants(self, damping_func):
- damping_id = graph_func_to_id(damping_func)
- return self._option1quants_by_damping[damping_id]
-
- def get_option2quants(self, damping_func):
- damping_id = graph_func_to_id(damping_func)
- return self._option2quants_by_damping[damping_id]
-
- def get_cov_dt1(self):
- assert self._cov_dt1 is not None
- return self._cov_dt1
-
- def register_cov_dt1(self):
- self._make_cov_dt1 = True
-
- def instantiate_cov_variables(self):
- super(FullyConnectedMultiKF, self).instantiate_cov_variables()
- assert self._cov_dt1 is None
- if self._make_cov_dt1:
- with variable_scope.variable_scope(self._var_scope):
- self._cov_dt1 = variable_scope.get_variable(
- "cov_dt1",
- initializer=init_ops.zeros_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
-
- def register_option1quants(self, damping_func):
- damping_id = self._register_damping(damping_func)
- if damping_id not in self._option1quants_registrations:
- self._option1quants_registrations.add(damping_id)
-
- def register_option2quants(self, damping_func):
- damping_id = self._register_damping(damping_func)
- if damping_id not in self._option2quants_registrations:
- self._option2quants_registrations.add(damping_id)
-
- def instantiate_inv_variables(self):
- super(FullyConnectedMultiKF, self).instantiate_inv_variables()
-
- for damping_id in self._option1quants_registrations:
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
- # It's questionable as to whether we should initialize with stuff like
- # this at all. Ideally these values should never be used until they are
- # updated at least once.
- with variable_scope.variable_scope(self._var_scope):
- Lmat = variable_scope.get_variable( # pylint: disable=invalid-name
- "Lmat_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- psi = variable_scope.get_variable(
- "psi_damp{}".format(damping_string),
- initializer=init_ops.ones_initializer,
- shape=self._vec_shape,
- trainable=False,
- dtype=self._dtype)
-
- assert damping_id not in self._option1quants_by_damping
- self._option1quants_by_damping[damping_id] = (Lmat, psi)
-
- for damping_id in self._option2quants_registrations:
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
- # It's questionable as to whether we should initialize with stuff like
- # this at all. Ideally these values should never be used until they are
- # updated at least once.
- with variable_scope.variable_scope(self._var_scope):
- Pmat = variable_scope.get_variable( # pylint: disable=invalid-name
- "Lmat_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- Kmat = variable_scope.get_variable( # pylint: disable=invalid-name
- "Kmat_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- mu = variable_scope.get_variable(
- "mu_damp{}".format(damping_string),
- initializer=init_ops.ones_initializer,
- shape=self._vec_shape,
- trainable=False,
- dtype=self._dtype)
-
- assert damping_id not in self._option2quants_by_damping
- self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu)
-
- def make_inverse_update_ops(self):
- """Create and return update ops corresponding to registered computations."""
- # TODO(b/69918258): Add correctness tests for this method.
- # pylint: disable=invalid-name
-
- ops = []
-
- if (len(self._option1quants_by_damping) +
- len(self._option2quants_by_damping)):
-
- # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from
- # the pseudo-code in the original paper. Because the computations for
- # the A and G case are essentially the same they can both be performed by
- # the same class (this one).
-
- C1 = self.get_cov_dt1()
-
- # Get the eigendecomposition of C0 (= self.get_cov())
- eigen_e, eigen_V = self.get_eigendecomp()
-
- # TODO(b/69678661): Note, there is an implicit assumption here that C1
- # and C0 (as represented here by its eigen-decomp) are consistent. This
- # could fail to be the case if self._cov and self._cov_dt1 are not updated
- # consistently, or are somehow read between or during the cov updates.
- # Can this possibly happen? Is there a way to prevent it?
-
- for damping_id, (Lmat_var,
- psi_var) in self._option1quants_by_damping.items():
-
- damping = self._damping_funcs_by_id[damping_id]()
- damping = math_ops.cast(damping, self._dtype)
-
- invsqrtC0 = math_ops.matmul(
- eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
-
- # Might need to enforce symmetry lost due to numerical issues.
- invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0
-
- # The following line imposes the symmetry assumed by "Option 1" on C1.
- # Strangely the code can work okay with this line commented out,
- # depending on how psd_eig is defined. I'm not sure why.
- C1 = (C1 + array_ops.transpose(C1)) / 2.0
-
- # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi})
- hPsi = math_ops.matmul(math_ops.matmul(invsqrtC0, C1), invsqrtC0)
-
- # Compute the decomposition U*diag(psi)*U^T = hPsi
- psi, U = utils.posdef_eig(hPsi)
-
- # L = C0^(-1/2) * U
- Lmat = math_ops.matmul(invsqrtC0, U)
-
- ops.append(Lmat_var.assign(Lmat))
- ops.append(psi_var.assign(psi))
-
- for damping_id, (Pmat_var, Kmat_var,
- mu_var) in self._option2quants_by_damping.items():
-
- damping = self._damping_funcs_by_id[damping_id]()
- damping = math_ops.cast(damping, self._dtype)
-
- # compute C0^(-1/2)
- invsqrtC0 = math_ops.matmul(
- eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
-
- # Might need to enforce symmetry lost due to numerical issues.
- invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0
-
- # Compute the product C0^(-1/2) * C1
- invsqrtC0C1 = math_ops.matmul(invsqrtC0, C1)
-
- # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi})
- hPsi = math_ops.matmul(invsqrtC0C1, invsqrtC0)
-
- # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi
- # Note that we using the notation mu instead of "m" for the eigenvalues.
- # Instead of computing the product hPsi^T * hPsi and then doing an
- # eigen-decomposition of this we just compute the SVD of hPsi and then
- # square the singular values to get the eigenvalues. For a justification
- # of this approach, see:
- # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition
- sqrtmu, _, E = linalg_ops.svd(hPsi)
- mu = math_ops.square(sqrtmu)
-
- # Mathematically, the eigenvalues should not should not exceed 1.0, but
- # due to numerical issues, or possible issues with inconsistent
- # values of C1 and (the eigen-decomposition of) C0 they might. So
- # we enforce this condition.
- mu = math_ops.minimum(mu, 1.0)
-
- # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1)
- Pmat = math_ops.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True)
-
- # K = C_0^(-1/2) * E
- Kmat = math_ops.matmul(invsqrtC0, E)
-
- ops.append(Pmat_var.assign(Pmat))
- ops.append(Kmat_var.assign(Kmat))
- ops.append(mu_var.assign(mu))
-
- ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops()
- return [control_flow_ops.group(*ops)]
-
- # pylint: enable=invalid-name
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
deleted file mode 100644
index 2d8e378a93..0000000000
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""FisherFactor definitions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.fisher_factors import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- "inverse_initializer", "covariance_initializer",
- "diagonal_covariance_initializer", "scope_string_from_params",
- "scope_string_from_name", "scalar_or_tensor_to_string", "FisherFactor",
- "InverseProvidingFactor", "FullFactor", "DiagonalFactor",
- "NaiveDiagonalFactor", "EmbeddingInputKroneckerFactor",
- "FullyConnectedDiagonalFactor", "FullyConnectedKroneckerFactor",
- "ConvInputKroneckerFactor", "ConvOutputKroneckerFactor",
- "ConvDiagonalFactor", "set_global_constants", "maybe_colocate_with",
- "compute_cov", "append_homog"
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
deleted file mode 100644
index 43aa713edc..0000000000
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ /dev/null
@@ -1,1269 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Registry for layers and their parameters/variables.
-
-This represents the collection of all layers in the approximate Fisher
-information matrix to which a particular FisherBlock may belong. That is, we
-might have several layer collections for one TF graph (if we have multiple K-FAC
-optimizers being used, for example.)
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from collections import defaultdict
-from collections import OrderedDict
-from contextlib import contextmanager
-from functools import partial
-import warnings
-
-import math
-import six
-
-from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
-from tensorflow.contrib.kfac.python.ops import loss_functions as lf
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.util import nest
-
-# Names for various approximations that can be requested for Fisher blocks.
-APPROX_KRONECKER_NAME = "kron"
-APPROX_DIAGONAL_NAME = "diagonal"
-APPROX_FULL_NAME = "full"
-
-_GENERIC_APPROX_TO_BLOCK_TYPES = {
- APPROX_FULL_NAME: fb.FullFB,
- APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB,
-}
-
-_FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB,
- APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB,
-}
-
-_CONV2D_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB,
- APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB,
-}
-
-_EMBEDDING_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_NAME: fb.EmbeddingKFACFB
-}
-
-APPROX_KRONECKER_INDEP_NAME = "kron_indep"
-APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1"
-APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2"
-
-_FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_INDEP_NAME: fb.FullyConnectedMultiIndepFB,
- APPROX_KRONECKER_SERIES_1_NAME: partial(fb.FullyConnectedSeriesFB,
- option=1),
- APPROX_KRONECKER_SERIES_2_NAME: partial(fb.FullyConnectedSeriesFB,
- option=2)
-}
-
-_CONV2D_MULTI_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB
-}
-
-_EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB
-}
-
-# Possible value for `reuse` keyword argument. Sets `reuse` to
-# tf.get_variable_scope().reuse.
-VARIABLE_SCOPE = "VARIABLE_SCOPE"
-
-_DEFAULT_LAYER_COLLECTION = None
-
-
-def get_default_layer_collection():
- """Get default LayerCollection."""
- if _DEFAULT_LAYER_COLLECTION is None:
- raise ValueError(
- "Attempted to retrieve default LayerCollection when none is set. Use "
- "LayerCollection.as_default().")
-
- return _DEFAULT_LAYER_COLLECTION
-
-
-def set_default_layer_collection(layer_collection):
- global _DEFAULT_LAYER_COLLECTION
-
- if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None:
- raise ValueError("Default LayerCollection is already set.")
-
- _DEFAULT_LAYER_COLLECTION = layer_collection
-
-
-class LayerParametersDict(OrderedDict):
- """An OrderedDict where keys are Tensors or tuples of Tensors.
-
- Ensures that no Tensor is associated with two different keys.
- """
-
- def __init__(self, *args, **kwargs):
- self._tensors = set()
- super(LayerParametersDict, self).__init__(*args, **kwargs)
-
- def __setitem__(self, key, value):
- key = self._canonicalize_key(key)
- tensors = key if isinstance(key, (tuple, list)) else (key,)
- key_collisions = self._tensors.intersection(tensors)
- if key_collisions:
- raise ValueError("Key(s) already present: {}".format(key_collisions))
- self._tensors.update(tensors)
- super(LayerParametersDict, self).__setitem__(key, value)
-
- def __delitem__(self, key):
- key = self._canonicalize_key(key)
- self._tensors.remove(key)
- super(LayerParametersDict, self).__delitem__(key)
-
- def __getitem__(self, key):
- key = self._canonicalize_key(key)
- return super(LayerParametersDict, self).__getitem__(key)
-
- def __contains__(self, key):
- key = self._canonicalize_key(key)
- return super(LayerParametersDict, self).__contains__(key)
-
- def _canonicalize_key(self, key):
- if isinstance(key, (list, tuple)):
- return tuple(key)
- return key
-
-
-# TODO(b/68034464): add capability for LayerCollection to be "finalized"
-# and do this when it gets used by FisherEstimator / KfacOptimizer.
-
-
-class LayerCollection(object):
- """Registry of information about layers and losses.
-
- Note that you need to create a new one of these for each MatrixEstimator or
- KfacOptimizer.
-
- Attributes:
- fisher_blocks: a LayersParamsDict (subclass of OrderedDict) mapping layer
- parameters (Tensors or tuples of Tensors) to FisherBlock instances.
- fisher_factors: an OrderedDict mapping tuples to FisherFactor instances.
- losses: a list of LossFunction objects. The loss to be optimized is their
- sum.
- loss_colocation_ops: ops to colocate loss function evaluations with. These
- will typically be the inputs to the losses.
- """
-
- def __init__(self,
- graph=None,
- name="LayerCollection"):
- warnings.warn(
- "tf.contrib.kfac is deprecated and will be removed by 2018-11-01. "
- "Use https://pypi.python.org/pypi/kfac instead.")
- self.fisher_blocks = LayerParametersDict()
- self.fisher_factors = OrderedDict()
- self._linked_parameters = dict(
- ) # dict mapping sets of variables to optionally specified approximations.
- self._graph = graph or ops.get_default_graph()
- self._loss_dict = {} # {str: LossFunction}
- self._subgraph = None
- self._default_generic_approximation = APPROX_DIAGONAL_NAME
- self._default_embedding_approximation = APPROX_KRONECKER_NAME
- self._default_fully_connected_approximation = APPROX_KRONECKER_NAME
- self._default_conv2d_approximation = APPROX_KRONECKER_NAME
- self._default_fully_connected_multi_approximation = (
- APPROX_KRONECKER_INDEP_NAME)
- self._default_conv2d_multi_approximation = (
- APPROX_KRONECKER_INDEP_NAME)
- self._default_embedding_multi_approximation = APPROX_KRONECKER_INDEP_NAME
- self.loss_colocation_ops = {}
- self._vars_to_uses = defaultdict(lambda: 0)
-
- with variable_scope.variable_scope(None, default_name=name) as scope:
- self._var_scope = scope.name
-
- @property
- def losses(self):
- """Tuple of LossFunction objects registered with this LayerCollection."""
- return nest.flatten(self.towers_by_loss)
-
- @property
- def towers_by_loss(self):
- """Tuple across losses of LossFunction objects registered to each tower."""
- return tuple(tuple(lst) for lst in self._loss_dict.values())
-
- @property
- def registered_variables(self):
- """A tuple of all of the variables currently registered."""
- tuple_of_tuples = (utils.ensure_sequence(key) for key, block
- in six.iteritems(self.fisher_blocks))
- flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_)
- return flat_tuple
-
- @property
- def linked_parameters(self):
- """Groups of parameters with an optionally specified approximation.
-
- Linked parameters can be added using `define_linked_parameters`.
- If an approximation is specified, then this approximation will be used
- when registering a layer with exactly these parameters, unless an
- approximation is specified when calling the registration function.
-
- Returns:
- A `dict` mapping tuples of parameters to an optional string.
- """
- return self._linked_parameters
-
- @property
- def default_embedding_approximation(self):
- return self._default_embedding_approximation
-
- def set_default_embedding_approximation(self, value):
- if value != APPROX_KRONECKER_NAME:
- raise ValueError(
- "{} is not a valid approximation for embedding variables.".format(
- value))
- self._default_embedding_approximation = value
-
- @property
- def default_generic_approximation(self):
- return self._default_generic_approximation
-
- def set_default_generic_approximation(self, value):
- if value not in _GENERIC_APPROX_TO_BLOCK_TYPES:
- raise ValueError(
- "{} is not a valid approximation for generic variables.".format(
- value))
- self._default_generic_approximation = value
-
- @property
- def default_fully_connected_approximation(self):
- return self._default_fully_connected_approximation
-
- def set_default_fully_connected_approximation(self, value):
- if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES:
- raise ValueError(
- "{} is not a valid approximation for fully connected layers.".format(
- value))
- self._default_fully_connected_approximation = value
-
- @property
- def default_conv2d_approximation(self):
- return self._default_conv2d_approximation
-
- def set_default_conv2d_approximation(self, value):
- if value not in _CONV2D_APPROX_TO_BLOCK_TYPES:
- raise ValueError(
- "{} is not a valid approximation for 2d convolutional layers.".format(
- value))
- self._default_conv2d_approximation = value
-
- @property
- def default_fully_connected_multi_approximation(self):
- return self._default_fully_connected_multi_approximation
-
- def set_default_fully_connected_multi_approximation(self, value):
- if value not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES:
- raise ValueError("{} is not a valid approximation for a fully-connected "
- "multi layer.".format(value))
- self._default_fully_connected_multi_approximation = value
-
- @property
- def default_conv2d_multi_approximation(self):
- return self._default_conv2d_multi_approximation
-
- @property
- def default_embedding_multi_approximation(self):
- return self._default_embedding_multi_approximation
-
- def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
- """Validates and registers the layer_key associated with the fisher_block.
-
- Args:
- layer_key: A variable or tuple of variables. The key to check for in
- existing registrations and to register if valid.
- fisher_block: The associated `FisherBlock`.
- reuse: Method to use for inserting new `FisherBlock's. One of True, False,
- or `VARIABLE_SCOPE`.
-
- Raises:
- ValueError: If `layer_key` was already registered and reuse is `False`,
- if `layer_key` was registered with a different block type, or if
- `layer_key` shares any variables with but is not equal to a previously
- registered key.
- KeyError: If `reuse` is `True` but `layer_key` was not previously
- registered.
-
- Returns:
- The `FisherBlock` registered under `layer_key`. If `layer_key` was already
- registered, this will be the previously registered `FisherBlock`.
- """
- if reuse is VARIABLE_SCOPE:
- reuse = variable_scope.get_variable_scope().reuse
-
- if reuse is True or (reuse is variable_scope.AUTO_REUSE and
- layer_key in self.fisher_blocks):
- result = self.fisher_blocks[layer_key]
- if type(result) != type(fisher_block): # pylint: disable=unidiomatic-typecheck
- raise ValueError(
- "Attempted to register FisherBlock of type %s when existing "
- "FisherBlock has type %s." % (type(fisher_block), type(result)))
- return result
- if reuse is False and layer_key in self.fisher_blocks:
- raise ValueError("FisherBlock for %s is already in LayerCollection." %
- (layer_key,))
-
- # Insert fisher_block into self.fisher_blocks.
- if layer_key in self.fisher_blocks:
- raise ValueError("Duplicate registration: {}".format(layer_key))
- # Raise an error if any variable in layer_key has been registered in any
- # other blocks.
- variable_to_block = {
- var: (params, block)
- for (params, block) in self.fisher_blocks.items()
- for var in utils.ensure_sequence(params)
- }
- for variable in utils.ensure_sequence(layer_key):
- if variable in variable_to_block:
- prev_key, prev_block = variable_to_block[variable]
- raise ValueError(
- "Attempted to register layer_key {} with block {}, but variable {}"
- " was already registered in key {} with block {}.".format(
- layer_key, fisher_block, variable, prev_key, prev_block))
- self.fisher_blocks[layer_key] = fisher_block
- return fisher_block
-
- def register_loss_function(self,
- loss,
- colocation_op,
- base_name,
- name=None,
- reuse=VARIABLE_SCOPE):
- """Registers a LossFunction object.
-
- Args:
- loss: The LossFunction object.
- colocation_op: The op to colocate the loss function's computations with.
- base_name: The name to derive a new unique name from is the name argument
- is None.
- name: (OPTIONAL) str or None. Unique name for this loss function. If None,
- a new name is generated. (Default: None)
- reuse: (OPTIONAL) bool or str. If True, adds `loss` as an additional
- tower for the existing loss function.
-
- Raises:
- ValueError: If reuse == True and name == None.
- ValueError: If reuse == True and seed != None.
- KeyError: If reuse == True and no existing LossFunction with `name` found.
- KeyError: If reuse == False and existing LossFunction with `name` found.
- """
-
- name = name or self._graph.unique_name(base_name)
-
- if reuse == VARIABLE_SCOPE:
- reuse = variable_scope.get_variable_scope().reuse
-
- if reuse:
- if name is None:
- raise ValueError(
- "If reuse is enabled, loss function's name must be set.")
-
- loss_list = self._loss_dict.get(name, None)
-
- if loss_list is None:
- raise KeyError(
- "Unable to find loss function named {}. Register a new loss "
- "function with reuse=False.".format(name))
- else:
- if name in self._loss_dict:
- raise KeyError(
- "Loss function named {} already exists. Set reuse=True to append "
- "another tower.".format(name))
-
- loss_list = []
- self._loss_dict[name] = loss_list
-
- loss_list.append(loss)
- self.loss_colocation_ops[loss] = colocation_op
-
- def _get_use_count_map(self):
- """Returns a dict mapping variables to their number of registrations."""
- return self._vars_to_uses
-
- def _add_uses(self, params, uses):
- """Register additional uses by params in the graph.
-
- Args:
- params: Variable or tuple of Variables. Parameters for a layer.
- uses: int or float. Number of additional uses for these parameters.
- """
- params = params if isinstance(params, (tuple, list)) else (params,)
- for var in params:
- self._vars_to_uses[var] += uses
-
- def check_registration(self, variables):
- """Checks that all variable uses have been registered properly.
-
- Args:
- variables: List of variables.
-
- Raises:
- ValueError: If any registered variables are not included in the list.
- ValueError: If any variable in the list is not registered.
- ValueError: If any variable in the list is registered with the wrong
- number of "uses" in the subgraph recorded (vs the number of times that
- variable is actually used in the subgraph).
- """
- # Note that overlapping parameters (i.e. those that share variables) will
- # be caught by layer_collection.LayerParametersDict during registration.
-
- reg_use_map = self._get_use_count_map()
-
- error_messages = []
-
- for var in variables:
- total_uses = self.subgraph.variable_uses(var)
- reg_uses = reg_use_map[var]
-
- if reg_uses == 0:
- error_messages.append("Variable {} not registered.".format(var))
- elif (not math.isinf(reg_uses)) and reg_uses != total_uses:
- error_messages.append(
- "Variable {} registered with wrong number of uses ({} "
- "registrations vs {} uses).".format(var, reg_uses, total_uses))
-
- num_get_vars = len(reg_use_map)
-
- if num_get_vars > len(variables):
- error_messages.append("{} registered variables were not included in list."
- .format(num_get_vars - len(variables)))
-
- if error_messages:
- error_messages = [
- "Found the following errors with variable registration:"
- ] + error_messages
- raise ValueError("\n\t".join(error_messages))
-
- def get_blocks(self):
- return self.fisher_blocks.values()
-
- def get_factors(self):
- return self.fisher_factors.values()
-
- @property
- def graph(self):
- return self._graph
-
- @property
- def subgraph(self):
- return self._subgraph
-
- def define_linked_parameters(self, params, approximation=None):
- """Identify a set of parameters that should be grouped together.
-
- During automatic graph scanning, any matches containing variables that have
- been identified as part of a linked group will be filtered out unless
- the match parameters are exactly equal to the ones specified in the linked
- group.
-
- Args:
- params: A variable, or a tuple or list of variables. The variables
- to be linked.
- approximation: Optional string specifying the type of approximation to use
- for these variables. If unspecified, this layer collection's default
- approximation for the layer type will be used.
-
- Raises:
- ValueError: If the parameters were already registered in a layer or
- identified as part of an incompatible group.
- """
- params = frozenset(utils.ensure_sequence(params))
-
- # Check if any of the variables in `params` is already in
- # 'self.fisher_blocks.keys()`.
- for registered_params, fisher_block in self.fisher_blocks.items():
- registered_params_set = set(utils.ensure_sequence(registered_params))
- for variable in params:
- if (variable in registered_params_set and
- params != registered_params_set):
- raise ValueError(
- "Can`t link parameters {}, variable {} was already registered in "
- "group {} with layer {}".format(params, variable,
- registered_params, fisher_block))
-
- # Check if any of the variables in `params` is already in
- # 'self.linked_parameters`.
- for variable in params:
- for other_linked_params in self.linked_parameters:
- if variable in other_linked_params:
- raise ValueError("Can`t link parameters {}, variable {} was already "
- "linked in group {}.".format(params, variable,
- other_linked_params))
- self._linked_parameters[params] = approximation
-
- def create_subgraph(self):
- if not self.losses:
- raise ValueError("Must have at least one registered loss.")
- inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses))
- self._subgraph = utils.SubGraph(inputs_to_losses)
-
- def eval_losses(self):
- """Return evaluated losses (colocated with inputs to losses)."""
- evals = []
- for loss in self.losses:
- with ops.colocate_with(self.loss_colocation_ops[loss]):
- evals.append(loss.evaluate())
- return evals
-
- def eval_losses_on_samples(self):
- """Return losses evaluated on samples (colocated with inputs to losses)."""
- evals = []
- for loss in self.losses:
- with ops.colocate_with(self.loss_colocation_ops[loss]):
- evals.append(loss.evaluate_on_sample())
- return evals
-
- def total_loss(self):
- return math_ops.add_n(self.eval_losses())
-
- def total_sampled_loss(self):
- return math_ops.add_n(self.eval_losses_on_samples())
-
- def _get_linked_approx(self, params):
- """If params were linked, return their specified approximation."""
- params_set = frozenset(utils.ensure_sequence(params))
- if params_set in self.linked_parameters:
- return self.linked_parameters[params_set]
- else:
- return None
-
- def _get_block_type(self, params, approx, default, approx_to_type):
- if approx is None:
- approx = self._get_linked_approx(params)
- if approx is None:
- approx = default
-
- if approx not in approx_to_type:
- raise ValueError("Bad value {} for approx.".format(approx))
-
- return approx_to_type[approx], approx
-
- def register_embedding(self,
- params,
- inputs,
- outputs,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers an embedding layer.
-
- Args:
- params: Embedding matrix of shape [vocab_size, embedding_size].
- inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices
- into embedding matrix.
- outputs: Tensor of shape [batch_size, embedding_size]. Outputs
- produced by layer.
- approx: str or None. If not None must be "kron". The Fisher
- approximation to use. If None the default value is used. (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_embedding_approximation,
- _EMBEDDING_APPROX_TO_BLOCK_TYPES)
-
- if isinstance(params, (tuple, list)):
- raise ValueError("Bias not supported.")
- vocab_size = int(params.shape[0])
- block = self.register_block(
- params, block_type(self, vocab_size), reuse=reuse)
- block.register_additional_tower(inputs, outputs)
-
- self._add_uses(params, 1)
-
- def register_fully_connected(self,
- params,
- inputs,
- outputs,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers a fully connected layer.
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [input_size, output_size].
- Bias should have shape [output_size].
- inputs: Tensor of shape [batch_size, input_size]. Inputs to layer.
- outputs: Tensor of shape [batch_size, output_size]. Outputs
- produced by layer.
- approx: str or None. If not None must be one of "kron" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
-
- block_type, approx = self._get_block_type(
- params, approx, self.default_fully_connected_approximation,
- _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES)
-
- has_bias = isinstance(params, (tuple, list))
- block = self.register_block(params, block_type(self, has_bias=has_bias),
- reuse=reuse)
- block.register_additional_tower(inputs, outputs)
-
- self._add_uses(params, 1)
-
- def register_conv2d(self,
- params,
- strides,
- padding,
- inputs,
- outputs,
- data_format=None,
- dilations=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers a call to tf.nn.conv2d().
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [kernel_height,
- kernel_width, in_channels, out_channels]. Bias should have shape
- [out_channels].
- strides: List of 4 ints. Strides for convolution kernel.
- padding: string. see tf.nn.conv2d for valid values.
- inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
- to layer.
- outputs: Tensor of shape [batch_size, height, width, out_channels].
- Output produced by layer.
- data_format: str or None. Format of data.
- dilations: List of 4 ints. Dilations along each dimension.
- approx: str or None. If not None must be one of "kron" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
-
- block_type, approx = self._get_block_type(
- params, approx, self.default_conv2d_approximation,
- _CONV2D_APPROX_TO_BLOCK_TYPES)
-
- # It feels bad to pass in configuration that has to do with the internal
- # implementation. And then we can`t use the same constructor for both
- # anymore and are thus forced to use this ugly if-statement.
- # TODO(b/74793309): Clean this up?
- if approx == APPROX_KRONECKER_NAME:
- block = self.register_block(
- params,
- block_type(
- layer_collection=self,
- params=params,
- padding=padding,
- strides=strides,
- data_format=data_format,
- dilation_rate=dilations,
- extract_patches_fn="extract_image_patches"),
- reuse=reuse)
- elif approx == APPROX_DIAGONAL_NAME:
- assert strides[0] == strides[-1] == 1
- block = self.register_block(
- params,
- block_type(
- layer_collection=self,
- params=params,
- padding=padding,
- strides=strides,
- dilations=dilations,
- data_format=data_format),
- reuse=reuse)
- else:
- raise NotImplementedError(approx)
-
- block.register_additional_tower(inputs, outputs)
-
- self._add_uses(params, 1)
-
- def register_convolution(self,
- params,
- inputs,
- outputs,
- padding,
- strides=None,
- dilation_rate=None,
- data_format=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Register a call to tf.nn.convolution().
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [..filter_spatial_size..,
- in_channels, out_channels]. Bias should have shape [out_channels].
- inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels].
- Inputs to layer.
- outputs: Tensor of shape [batch_size, ..output_spatial_size..,
- out_channels]. Output produced by layer.
- padding: string. see tf.nn.conv2d for valid values.
- strides: List of ints of length len(..input_spatial_size..). Strides for
- convolution kernel in spatial dimensions.
- dilation_rate: List of ints of length len(..input_spatial_size..).
- Dilations along spatial dimension.
- data_format: str or None. Format of data.
- approx: str or None. If not None must be one of "kron" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- # TODO(b/74793309): Have this use _get_block_type like the other
- # registration functions?
- assert approx is None or approx == APPROX_KRONECKER_NAME
-
- block = self.register_block(
- params,
- fb.ConvKFCBasicFB(
- layer_collection=self,
- params=params,
- padding=padding,
- strides=strides,
- dilation_rate=dilation_rate,
- data_format=data_format),
- reuse=reuse)
- block.register_additional_tower(inputs, outputs)
-
- self._add_uses(params, 1)
-
- def register_depthwise_conv2d(self,
- params,
- inputs,
- outputs,
- strides,
- padding,
- rate=None,
- data_format=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Register a call to tf.nn.depthwise_conv2d().
-
- Args:
- params: 4-D Tensor of shape [filter_height, filter_width,
- in_channels, channel_multiplier]. Convolutional filter.
- inputs: Tensor of shape [batch_size, input_height, input_width,
- in_channels]. Inputs to layer.
- outputs: Tensor of shape [batch_size, output_height, output_width,
- in_channels * channel_multiplier]. Output produced by depthwise conv2d.
- strides: List of ints of length 4. Strides along all dimensions.
- padding: string. see tf.nn.conv2d for valid values.
- rate: None or List of ints of length 2. Dilation rates in spatial
- dimensions.
- data_format: str or None. Format of data.
- approx: str or None. If not None must "diagonal". The Fisher
- approximation to use. If None the default value is used. (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- # TODO(b/74793309): Have this use _get_block_type like the other
- # registration functions?
- assert approx is None or approx == APPROX_DIAGONAL_NAME
- assert data_format in [None, "NHWC"]
-
- block = self.register_block(
- params,
- fb.DepthwiseConvDiagonalFB(
- layer_collection=self,
- params=params,
- strides=strides,
- padding=padding,
- rate=rate,
- data_format=data_format),
- reuse=reuse)
- block.register_additional_tower(inputs, outputs)
-
- self._add_uses(params, 1)
-
- def register_separable_conv2d(self,
- depthwise_params,
- pointwise_params,
- inputs,
- depthwise_outputs,
- pointwise_outputs,
- strides,
- padding,
- rate=None,
- data_format=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Register a call to tf.nn.separable_conv2d().
-
- Note: This requires access to intermediate outputs between depthwise and
- pointwise convolutions.
-
- Args:
- depthwise_params: 4-D Tensor of shape [filter_height, filter_width,
- in_channels, channel_multiplier]. Filter for depthwise conv2d.
- pointwise_params: 4-D Tensor of shape [1, 1, in_channels *
- channel_multiplier, out_channels]. Filter for pointwise conv2d.
- inputs: Tensor of shape [batch_size, input_height, input_width,
- in_channels]. Inputs to layer.
- depthwise_outputs: Tensor of shape [batch_size, output_height,
- output_width, in_channels * channel_multiplier]. Output produced by
- depthwise conv2d.
- pointwise_outputs: Tensor of shape [batch_size, output_height,
- output_width, out_channels]. Output produced by pointwise conv2d.
- strides: List of ints of length 4. Strides for depthwise conv2d kernel in
- all dimensions.
- padding: string. see tf.nn.conv2d for valid values.
- rate: None or List of ints of length 2. Dilation rate of depthwise conv2d
- kernel in spatial dimensions.
- data_format: str or None. Format of data.
- approx: str or None. If not None must be one of "kron" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- self.register_depthwise_conv2d(
- params=depthwise_params,
- inputs=inputs,
- outputs=depthwise_outputs,
- strides=strides,
- padding=padding,
- rate=rate,
- data_format=data_format,
- approx=APPROX_DIAGONAL_NAME,
- reuse=reuse)
-
- self.register_conv2d(
- params=pointwise_params,
- inputs=depthwise_outputs,
- outputs=pointwise_outputs,
- strides=[1, 1, 1, 1],
- padding="VALID",
- data_format=data_format,
- approx=approx,
- reuse=reuse)
-
- def register_generic(self,
- params,
- batch_size,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers a generic layer.
-
- Args:
- params: Tensor or tuple of Tensors corresponding to the parameters.
- batch_size: 0-D Tensor. Size of the minibatch (for this tower).
- approx: str or None. It not None, must be one of "full" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `batch_size` to the total
- mini-batch size use when estimating the Fisher block for this layer
- (which must have already been registered). If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_generic_approximation,
- _GENERIC_APPROX_TO_BLOCK_TYPES)
-
- block = self.register_block(params, block_type(self, params), reuse=reuse)
- block.register_additional_tower(batch_size)
-
- self._add_uses(params, float("inf"))
-
- def register_fully_connected_multi(self, params, inputs, outputs,
- num_uses=None, approx=None,
- reuse=VARIABLE_SCOPE):
- """Register fully connected layers with shared parameters.
-
- This can handle general fully-connected layers with shared parameters, but
- has specialized approximations to deal with the case where there is a
- meaningful linear order to the share instances (such as in an RNN).
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [input_size, output_size].
- Bias should have shape [output_size].
- inputs: A list of Tensors, each of shape [batch_size, input_size]. Inputs
- to layer. The list indexes each use in the graph (which might
- correspond to a "time-step" in an RNN). OR, can be single Tensor, of
- shape [num_uses * batch_size , input_size], which is a reshaped version
- of a Tensor of shape [num_uses, batch_size, input_size].
- outputs: A list of Tensors, the same length as `inputs`, each of shape
- [batch_size, output_size]. Outputs produced by layer. The list indexes
- each use in the graph (which might correspond to a "time-step" in an
- RNN). Needs to correspond with the order used in `inputs`. OR, can be
- a single Tensor of shape [num_uses * batch_size, output_size], which is
- a reshaped version of a Tensor of shape [num_uses, batch_size,
- output_size].
- num_uses: int or None. The number uses/time-steps in the graph where the
- layer appears. Only needed if both inputs and outputs are given in the
- single Tensor format. (Default: None)
- approx: str or None. If not None, must be of "kron_indep", "kron_series_1"
- or "kron_series_2". The Fisher approximation to use. If None the default
- value is used. (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
- word `use` here has a completely different meaning to "use in the graph"
- as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_fully_connected_multi_approximation,
- _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES)
-
- # TODO(b/70283649): something along the lines of find_canonical_output
- # should be added back in here (and for the other block types, arguably).
-
- has_bias = isinstance(params, (tuple, list))
- block = self.register_block(params, block_type(self, has_bias=has_bias,
- num_uses=num_uses),
- reuse=reuse)
- block.register_additional_tower(inputs, outputs)
- if isinstance(inputs, (tuple, list)):
- assert len(inputs) == len(outputs)
- self._add_uses(params, len(inputs))
- else:
- self._add_uses(params, 1)
-
- def register_conv2d_multi(self,
- params,
- strides,
- padding,
- inputs,
- outputs,
- num_uses=None,
- data_format=None,
- dilations=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers convolutional layers with shared parameters.
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [kernel_height,
- kernel_width, in_channels, out_channels]. Bias should have shape
- [out_channels].
- strides: 1-D Tensor of length 4. Strides for convolution kernel.
- padding: string. see tf.nn.conv2d for valid values.
- inputs: A list of Tensors, each of shape [batch_size, height, width,
- in_channels]. Inputs to layer. The list indexes each use in the graph
- (which might correspond to a "time-step" in an RNN). OR, can be single
- Tensor, of shape [num_uses * batch_size, height, width, in_channels],
- which is a reshaped version of a Tensor of shape [num_uses, batch_size,
- height, width, in_channels].
- outputs: A list of Tensors, each of shape [batch_size, height, width,
- out_channels]. Output produced by layer. The list indexes each use
- in the graph (which might correspond to a "time-step" in an RNN).
- Needs to correspond with the order used in `inputs`. OR, can be a
- single Tensor, of shape [num_uses * batch_size, height, width,
- out_channels], which is a reshaped version of a Tensor of shape
- [num_uses, batch_size, height, width, out_channels].
- num_uses: int or None. The number uses/time-steps in the graph where the
- layer appears. Only needed if both inputs and outputs are given in the
- single Tensor format. (Default: None)
- data_format: str or None. Format of data.
- dilations: List of 4 ints. Dilations along each dimension.
- approx: str or None. If not None must by "kron_indep". The Fisher
- approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
- word `use` here has a completely different meaning to "use in the graph"
- as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_conv2d_multi_approximation,
- _CONV2D_MULTI_APPROX_TO_BLOCK_TYPES)
-
- block = self.register_block(
- params,
- block_type(
- layer_collection=self,
- params=params,
- padding=padding,
- strides=strides,
- data_format=data_format,
- dilation_rate=dilations,
- extract_patches_fn="extract_image_patches",
- num_uses=num_uses),
- reuse=reuse)
-
- block.register_additional_tower(inputs, outputs)
- if isinstance(inputs, (tuple, list)):
- assert len(inputs) == len(outputs)
- self._add_uses(params, len(inputs))
- else:
- self._add_uses(params, 1)
-
- # TODO(b/74108452): change the loss registration functions names to refer
- # to "loss functions" instead of distributions. Following naming convention
- # of the loss function classes themselves.
-
- def register_embedding_multi(self,
- params,
- inputs,
- outputs,
- num_uses=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers embedding layers with shared parameters.
-
- Args:
- params: Embedding matrix of shape [vocab_size, embedding_size].
- inputs: A list of Tensors, each of shape [batch_size, input_size] and
- dtype int32. Indices into embedding matrix. The list indexes each use
- in the graph (which might correspond to a "time-step" in an RNN).
- OR, can be single Tensor, of shape [num_uses*batch_size, input_size],
- which is a reshaped version of a Tensor of shape [num_uses, batch_size,
- input_size].
- outputs: A list of Tensors, each of shape [batch_size, embedding_size].
- Outputs produced by layer. The list indexes each use in the graph
- (which might correspond to a "time-step" in an RNN). Needs to
- correspond with the order used in `inputs`. OR, can be a
- single Tensor, of shape [num_uses * batch_size, embedding_size], which
- is a reshaped version of a Tensor of shape [num_uses, batch_size,
- embedding_size].
- num_uses: int or None. The number uses/time-steps in the graph where the
- layer appears. Only needed if both inputs and outputs are given in the
- single Tensor format. (Default: None)
- approx: str or None. If not None must by "kron_indep". The Fisher
- approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
- word `use` here has a completely different meaning to "use in the graph"
- as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_embedding_multi_approximation,
- _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES)
-
- if isinstance(params, (tuple, list)):
- raise ValueError("Bias not supported.")
- vocab_size = int(params.shape[0])
-
- block = self.register_block(
- params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse)
- block.register_additional_tower(inputs, outputs)
-
- if isinstance(inputs, (tuple, list)):
- self._add_uses(params, len(inputs))
- else:
- self._add_uses(params, 1)
-
- def register_categorical_predictive_distribution(self,
- logits,
- seed=None,
- targets=None,
- name=None,
- reuse=VARIABLE_SCOPE):
- """Registers a categorical predictive distribution.
-
- Args:
- logits: The logits of the distribution (i.e. its parameters).
- seed: The seed for the RNG (for debugging) (Default: None)
- targets: (OPTIONAL) The targets for the loss function. Only required if
- one wants to call total_loss() instead of total_sampled_loss().
- total_loss() is required, for example, to estimate the
- "empirical Fisher" (instead of the true Fisher).
- (Default: None)
- name: (OPTIONAL) str or None. Unique name for this loss function. If None,
- a new name is generated. (Default: None)
- reuse: bool or str. If True, this adds `logits` as an additional
- mini-batch/tower of inputs to the loss-function/predictive distribution
- (which must have already been registered). If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
- """
- loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets,
- seed=seed)
- self.register_loss_function(loss, logits,
- "categorical_predictive_distribution",
- name=name, reuse=reuse)
-
- def register_normal_predictive_distribution(self,
- mean,
- var=0.5,
- seed=None,
- targets=None,
- name=None,
- reuse=VARIABLE_SCOPE):
- """Registers a normal predictive distribution.
-
- Args:
- mean: The mean vector defining the distribution.
- var: The variance (must be a scalar). Note that the default value of
- 0.5 corresponds to a standard squared error loss (target -
- prediction)**2. If your squared error loss is of the form
- 0.5*(target - prediction)**2 you should use var=1.0. (Default: 0.5)
- seed: The seed for the RNG (for debugging) (Default: None)
- targets: (OPTIONAL) The targets for the loss function. Only required if
- one wants to call total_loss() instead of total_sampled_loss().
- total_loss() is required, for example, to estimate the
- "empirical Fisher" (instead of the true Fisher).
- (Default: None)
- name: (OPTIONAL) str or None. Unique name for this loss function. If None,
- a new name is generated. (Default: None)
- reuse: bool or str. If True, this adds `mean` and `var` as an additional
- mini-batch/tower of inputs to the loss-function/predictive distribution
- (which must have already been registered). If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
- """
- loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets,
- seed=seed)
- self.register_loss_function(loss, mean,
- "normal_predictive_distribution",
- name=name, reuse=reuse)
-
- def register_multi_bernoulli_predictive_distribution(self,
- logits,
- seed=None,
- targets=None,
- name=None,
- reuse=VARIABLE_SCOPE):
- """Registers a multi-Bernoulli predictive distribution.
-
- Args:
- logits: The logits of the distribution (i.e. its parameters).
- seed: The seed for the RNG (for debugging) (Default: None)
- targets: (OPTIONAL) The targets for the loss function. Only required if
- one wants to call total_loss() instead of total_sampled_loss().
- total_loss() is required, for example, to estimate the
- "empirical Fisher" (instead of the true Fisher).
- (Default: None)
- name: (OPTIONAL) str or None. Unique name for this loss function. If None,
- a new name is generated. (Default: None)
- reuse: bool or str. If True, this adds `logits` as an additional
- mini-batch/tower of inputs to the loss-function/predictive distribution
- (which must have already been registered). If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
- """
- loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets,
- seed=seed)
- self.register_loss_function(loss, logits,
- "multi_bernoulli_predictive_distribution",
- name=name, reuse=reuse)
-
- def make_or_get_factor(self, cls, args):
- """Insert `cls(args)` into 'self.fisher_factors` if not already present.
-
- Wraps constructor in `tf.variable_scope()` to ensure variables constructed
- in `cls.__init__` are placed under this LayerCollection's scope.
-
- Args:
- cls: Class that implements FisherFactor.
- args: Tuple of arguments to pass into `cls's constructor. Must be
- hashable.
-
- Returns:
- Instance of `cls` found in self.fisher_factors.
- """
- try:
- hash(args)
- except TypeError:
- raise TypeError(
- ("Unable to use (cls, args) = ({}, {}) as a key in "
- "LayerCollection.fisher_factors. The pair cannot be hashed.").format(
- cls, args))
-
- key = cls, args
- if key not in self.fisher_factors:
- with variable_scope.variable_scope(self._var_scope):
- self.fisher_factors[key] = cls(*args)
- return self.fisher_factors[key]
-
- @contextmanager
- def as_default(self):
- """Sets this LayerCollection as the default."""
- set_default_layer_collection(self)
- yield
- set_default_layer_collection(None)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
deleted file mode 100644
index 9f46853807..0000000000
--- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Registry for layers and their parameters/variables.
-
-This represents the collection of all layers in the approximate Fisher
-information matrix to which a particular FisherBlock may belong. That is, we
-might have several layer collections for one TF graph (if we have multiple K-FAC
-optimizers being used, for example.)
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.layer_collection import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- "get_default_layer_collection",
- "set_default_layer_collection",
- "LayerParametersDict",
- "LayerCollection",
- "APPROX_KRONECKER_NAME",
- "APPROX_DIAGONAL_NAME",
- "APPROX_FULL_NAME",
- "VARIABLE_SCOPE",
- "APPROX_KRONECKER_INDEP_NAME",
- "APPROX_KRONECKER_SERIES_1_NAME",
- "APPROX_KRONECKER_SERIES_2_NAME"
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/linear_operator.py b/tensorflow/contrib/kfac/python/ops/linear_operator.py
deleted file mode 100644
index 61cb955ae8..0000000000
--- a/tensorflow/contrib/kfac/python/ops/linear_operator.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""SmartMatrices definitions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.linalg import linalg
-from tensorflow.python.ops.linalg import linalg_impl
-from tensorflow.python.ops.linalg import linear_operator_util as lou
-
-
-class LinearOperatorExtras(object): # pylint: disable=missing-docstring
-
- def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
-
- with self._name_scope(name, values=[x]):
- if isinstance(x, ops.IndexedSlices):
- return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
-
- x = ops.convert_to_tensor(x, name="x")
- self._check_input_dtype(x)
-
- self_dim = -2 if adjoint else -1
- arg_dim = -1 if adjoint_arg else -2
- self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
-
- return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
-
- def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
-
- with self._name_scope(name, values=[x]):
-
- if isinstance(x, ops.IndexedSlices):
- return self._matmul_right_sparse(
- x, adjoint=adjoint, adjoint_arg=adjoint_arg)
-
- x = ops.convert_to_tensor(x, name="x")
- self._check_input_dtype(x)
-
- self_dim = -1 if adjoint else -2
- arg_dim = -2 if adjoint_arg else -1
- self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
-
- return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
-
-
-class LinearOperatorFullMatrix(LinearOperatorExtras,
- linalg.LinearOperatorFullMatrix):
-
- # TODO(b/78117889) Remove this definition once core LinearOperator
- # has _matmul_right.
- def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
- return lou.matmul_with_broadcast(
- x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint)
-
- def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
- raise NotImplementedError
-
- def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
- assert not adjoint and not adjoint_arg
- return utils.matmul_sparse_dense(x, self._matrix)
-
-
-class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring
- linalg.LinearOperatorDiag):
-
- def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
- diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
- x = linalg_impl.adjoint(x) if adjoint_arg else x
- return diag_mat * x
-
- def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
- diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
- assert not adjoint_arg
- return utils.matmul_diag_sparse(diag_mat, x)
-
- def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
- raise NotImplementedError
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py
deleted file mode 100644
index c8cebc42cb..0000000000
--- a/tensorflow/contrib/kfac/python/ops/loss_functions.py
+++ /dev/null
@@ -1,754 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Loss functions to be used by LayerCollection."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-
-import six
-
-from tensorflow.contrib.distributions.python.ops import onehot_categorical
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.distributions import bernoulli
-from tensorflow.python.ops.distributions import categorical
-from tensorflow.python.ops.distributions import normal
-
-
-@six.add_metaclass(abc.ABCMeta)
-class LossFunction(object):
- """Abstract base class for loss functions.
-
- Note that unlike typical loss functions used in neural networks these are
- summed and not averaged across cases in the batch, since this is what the
- users of this class (FisherEstimator and MatrixVectorProductComputer) will
- be expecting. The implication of this is that you will may want to
- normalize things like Fisher-vector products by the batch size when you
- use this class. It depends on the use case.
- """
-
- @abc.abstractproperty
- def targets(self):
- """The targets being predicted by the model.
-
- Returns:
- None or Tensor of appropriate shape for calling self._evaluate() on.
- """
- pass
-
- @abc.abstractproperty
- def inputs(self):
- """The inputs to the loss function (excluding the targets)."""
- pass
-
- def evaluate(self):
- """Evaluate the loss function on the targets."""
- if self.targets is not None:
- # We treat the targets as "constant". It's only the inputs that get
- # "back-propped" through.
- return self._evaluate(array_ops.stop_gradient(self.targets))
- else:
- raise Exception("Cannot evaluate losses with unspecified targets.")
-
- @abc.abstractmethod
- def _evaluate(self, targets):
- """Evaluates the negative log probability of the targets.
-
- Args:
- targets: Tensor that distribution can calculate log_prob() of.
-
- Returns:
- negative log probability of each target, summed across all targets.
- """
- pass
-
- @abc.abstractmethod
- def multiply_hessian(self, vector):
- """Right-multiply a vector by the Hessian.
-
- Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
- of the loss function with respect to its inputs.
-
- Args:
- vector: The vector to multiply. Must be the same shape(s) as the
- 'inputs' property.
-
- Returns:
- The vector right-multiplied by the Hessian. Will be of the same shape(s)
- as the 'inputs' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_hessian_factor(self, vector):
- """Right-multiply a vector by a factor B of the Hessian.
-
- Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
- of the loss function with respect to its inputs. Typically this will be
- block-diagonal across different cases in the batch, since the loss function
- is typically summed across cases.
-
- Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
- but will agree with the one used in the other methods of this class.
-
- Args:
- vector: The vector to multiply. Must be of the shape given by the
- 'hessian_factor_inner_shape' property.
-
- Returns:
- The vector right-multiplied by B. Will be of the same shape(s) as the
- 'inputs' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_hessian_factor_transpose(self, vector):
- """Right-multiply a vector by the transpose of a factor B of the Hessian.
-
- Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
- of the loss function with respect to its inputs. Typically this will be
- block-diagonal across different cases in the batch, since the loss function
- is typically summed across cases.
-
- Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
- but will agree with the one used in the other methods of this class.
-
- Args:
- vector: The vector to multiply. Must be the same shape(s) as the
- 'inputs' property.
-
- Returns:
- The vector right-multiplied by B^T. Will be of the shape given by the
- 'hessian_factor_inner_shape' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_hessian_factor_replicated_one_hot(self, index):
- """Right-multiply a replicated-one-hot vector by a factor B of the Hessian.
-
- Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
- of the loss function with respect to its inputs. Typically this will be
- block-diagonal across different cases in the batch, since the loss function
- is typically summed across cases.
-
- A 'replicated-one-hot' vector means a tensor which, for each slice along the
- batch dimension (assumed to be dimension 0), is 1.0 in the entry
- corresponding to the given index and 0 elsewhere.
-
- Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
- but will agree with the one used in the other methods of this class.
-
- Args:
- index: A tuple representing in the index of the entry in each slice that
- is 1.0. Note that len(index) must be equal to the number of elements
- of the 'hessian_factor_inner_shape' tensor minus one.
-
- Returns:
- The vector right-multiplied by B^T. Will be of the same shape(s) as the
- 'inputs' property.
- """
- pass
-
- @abc.abstractproperty
- def hessian_factor_inner_shape(self):
- """The shape of the tensor returned by multiply_hessian_factor."""
- pass
-
- @abc.abstractproperty
- def hessian_factor_inner_static_shape(self):
- """Static version of hessian_factor_inner_shape."""
- pass
-
-
-@six.add_metaclass(abc.ABCMeta)
-class NegativeLogProbLoss(LossFunction):
- """Abstract base class for loss functions that are negative log probs."""
-
- def __init__(self, seed=None):
- self._default_seed = seed
- super(NegativeLogProbLoss, self).__init__()
-
- @property
- def inputs(self):
- return self.params
-
- @abc.abstractproperty
- def params(self):
- """Parameters to the underlying distribution."""
- pass
-
- @abc.abstractmethod
- def multiply_fisher(self, vector):
- """Right-multiply a vector by the Fisher.
-
- Args:
- vector: The vector to multiply. Must be the same shape(s) as the
- 'inputs' property.
-
- Returns:
- The vector right-multiplied by the Fisher. Will be of the same shape(s)
- as the 'inputs' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_fisher_factor(self, vector):
- """Right-multiply a vector by a factor B of the Fisher.
-
- Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
- product of gradients) with respect to the parameters of the underlying
- probability distribution (whose log-prob defines the loss). Typically this
- will be block-diagonal across different cases in the batch, since the
- distribution is usually (but not always) conditionally iid across different
- cases.
-
- Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
- but will agree with the one used in the other methods of this class.
-
- Args:
- vector: The vector to multiply. Must be of the shape given by the
- 'fisher_factor_inner_shape' property.
-
- Returns:
- The vector right-multiplied by B. Will be of the same shape(s) as the
- 'inputs' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_fisher_factor_transpose(self, vector):
- """Right-multiply a vector by the transpose of a factor B of the Fisher.
-
- Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
- product of gradients) with respect to the parameters of the underlying
- probability distribution (whose log-prob defines the loss). Typically this
- will be block-diagonal across different cases in the batch, since the
- distribution is usually (but not always) conditionally iid across different
- cases.
-
- Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
- but will agree with the one used in the other methods of this class.
-
- Args:
- vector: The vector to multiply. Must be the same shape(s) as the
- 'inputs' property.
-
- Returns:
- The vector right-multiplied by B^T. Will be of the shape given by the
- 'fisher_factor_inner_shape' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_fisher_factor_replicated_one_hot(self, index):
- """Right-multiply a replicated-one-hot vector by a factor B of the Fisher.
-
- Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
- product of gradients) with respect to the parameters of the underlying
- probability distribution (whose log-prob defines the loss). Typically this
- will be block-diagonal across different cases in the batch, since the
- distribution is usually (but not always) conditionally iid across different
- cases.
-
- A 'replicated-one-hot' vector means a tensor which, for each slice along the
- batch dimension (assumed to be dimension 0), is 1.0 in the entry
- corresponding to the given index and 0 elsewhere.
-
- Note that B can be any matrix satisfying B * B^T = H where H is the Fisher,
- but will agree with the one used in the other methods of this class.
-
- Args:
- index: A tuple representing in the index of the entry in each slice that
- is 1.0. Note that len(index) must be equal to the number of elements
- of the 'fisher_factor_inner_shape' tensor minus one.
-
- Returns:
- The vector right-multiplied by B. Will be of the same shape(s) as the
- 'inputs' property.
- """
- pass
-
- @abc.abstractproperty
- def fisher_factor_inner_shape(self):
- """The shape of the tensor returned by multiply_fisher_factor."""
- pass
-
- @abc.abstractproperty
- def fisher_factor_inner_static_shape(self):
- """Static version of fisher_factor_inner_shape."""
- pass
-
- @abc.abstractmethod
- def sample(self, seed):
- """Sample 'targets' from the underlying distribution."""
- pass
-
- def evaluate_on_sample(self, seed=None):
- """Evaluates the log probability on a random sample.
-
- Args:
- seed: int or None. Random seed for this draw from the distribution.
-
- Returns:
- Log probability of sampled targets, summed across examples.
- """
- if seed is None:
- seed = self._default_seed
- # We treat the targets as "constant". It's only the inputs that get
- # "back-propped" through.
- return self._evaluate(array_ops.stop_gradient(self.sample(seed)))
-
-
-# TODO(jamesmartens): should this just inherit from object to avoid "diamond"
-# inheritance, or is there a better way?
-class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss):
- """Base class for neg log prob losses whose inputs are 'natural' parameters.
-
- Note that the Hessian and Fisher for natural parameters of exponential-
- family models are the same, hence the purpose of this class.
- See here: https://arxiv.org/abs/1412.1193
-
- 'Natural parameters' are defined for exponential-family models. See for
- example: https://en.wikipedia.org/wiki/Exponential_family
- """
-
- def multiply_hessian(self, vector):
- return self.multiply_fisher(vector)
-
- def multiply_hessian_factor(self, vector):
- return self.multiply_fisher_factor(vector)
-
- def multiply_hessian_factor_transpose(self, vector):
- return self.multiply_fisher_factor_transpose(vector)
-
- def multiply_hessian_factor_replicated_one_hot(self, index):
- return self.multiply_fisher_factor_replicated_one_hot(index)
-
- @property
- def hessian_factor_inner_shape(self):
- return self.fisher_factor_inner_shape
-
- @property
- def hessian_factor_inner_static_shape(self):
- return self.fisher_factor_inner_shape
-
-
-class DistributionNegativeLogProbLoss(NegativeLogProbLoss):
- """Base class for neg log prob losses that use the TF Distribution classes."""
-
- def __init__(self, seed=None):
- super(DistributionNegativeLogProbLoss, self).__init__(seed=seed)
-
- @abc.abstractproperty
- def dist(self):
- """The underlying tf.distributions.Distribution."""
- pass
-
- def _evaluate(self, targets):
- return -math_ops.reduce_sum(self.dist.log_prob(targets))
-
- def sample(self, seed):
- return self.dist.sample(seed=seed)
-
-
-class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
- NaturalParamsNegativeLogProbLoss):
- """Neg log prob loss for a normal distribution parameterized by a mean vector.
-
-
- Note that the covariance is treated as a constant 'var' times the identity.
- Also note that the Fisher for such a normal distribution with respect the mean
- parameter is given by:
-
- F = (1/var) * I
-
- See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.
- """
-
- def __init__(self, mean, var=0.5, targets=None, seed=None):
- self._mean = mean
- self._var = var
- self._targets = targets
- super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed)
-
- @property
- def targets(self):
- return self._targets
-
- @property
- def dist(self):
- return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var))
-
- @property
- def params(self):
- return self._mean
-
- def multiply_fisher(self, vector):
- return (1. / self._var) * vector
-
- def multiply_fisher_factor(self, vector):
- return self._var**-0.5 * vector
-
- def multiply_fisher_factor_transpose(self, vector):
- return self.multiply_fisher_factor(vector) # it's symmetric in this case
-
- def multiply_fisher_factor_replicated_one_hot(self, index):
- assert len(index) == 1, "Length of index was {}".format(len(index))
- ones_slice = array_ops.expand_dims(
- array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype),
- axis=-1)
- output_slice = self._var**-0.5 * ones_slice
- return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]),
- index[0])
-
- @property
- def fisher_factor_inner_shape(self):
- return array_ops.shape(self._mean)
-
- @property
- def fisher_factor_inner_static_shape(self):
- return self._mean.shape
-
-
-class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
- """Negative log prob loss for a normal distribution with mean and variance.
-
- This class parameterizes a multivariate normal distribution with n independent
- dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not
- assume the variance is held constant. The Fisher Information for n = 1
- is given by,
-
- F = [[1 / variance, 0],
- [ 0, 0.5 / variance^2]]
-
- where the parameters of the distribution are concatenated into a single
- vector as [mean, variance]. For n > 1, the mean parameter vector is
- concatenated with the variance parameter vector.
-
- See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation.
- """
-
- def __init__(self, mean, variance, targets=None, seed=None):
- assert len(mean.shape) == 2, "Expect 2D mean tensor."
- assert len(variance.shape) == 2, "Expect 2D variance tensor."
- self._mean = mean
- self._variance = variance
- self._targets = targets
- super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed)
-
- @property
- def targets(self):
- return self._targets
-
- @property
- def dist(self):
- return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance))
-
- @property
- def params(self):
- return self._mean, self._variance
-
- def _concat(self, mean, variance):
- return array_ops.concat([mean, variance], axis=-1)
-
- def _split(self, params):
- return array_ops.split(params, 2, axis=-1)
-
- @property
- def _fisher_mean(self):
- return 1. / self._variance
-
- @property
- def _fisher_mean_factor(self):
- return 1. / math_ops.sqrt(self._variance)
-
- @property
- def _fisher_var(self):
- return 1. / (2 * math_ops.square(self._variance))
-
- @property
- def _fisher_var_factor(self):
- return 1. / (math_ops.sqrt(2.) * self._variance)
-
- def multiply_fisher(self, vecs):
- mean_vec, var_vec = vecs
- return (self._fisher_mean * mean_vec, self._fisher_var * var_vec)
-
- def multiply_fisher_factor(self, vecs):
- mean_vec, var_vec = self._split(vecs)
- return (self._fisher_mean_factor * mean_vec,
- self._fisher_var_factor * var_vec)
-
- def multiply_fisher_factor_transpose(self, vecs):
- mean_vec, var_vec = vecs
- return self._concat(self._fisher_mean_factor * mean_vec,
- self._fisher_var_factor * var_vec)
-
- def multiply_fisher_factor_replicated_one_hot(self, index):
- assert len(index) == 1, "Length of index was {}".format(len(index))
- index = index[0]
-
- if index < int(self._mean.shape[-1]):
- # Index corresponds to mean parameter.
- mean_slice = self._fisher_mean_factor[:, index]
- mean_slice = array_ops.expand_dims(mean_slice, axis=-1)
- mean_output = insert_slice_in_zeros(mean_slice, 1, int(
- self._mean.shape[1]), index)
- var_output = array_ops.zeros_like(mean_output)
- else:
- index -= int(self._mean.shape[-1])
- # Index corresponds to variance parameter.
- var_slice = self._fisher_var_factor[:, index]
- var_slice = array_ops.expand_dims(var_slice, axis=-1)
- var_output = insert_slice_in_zeros(var_slice, 1,
- int(self._variance.shape[1]), index)
- mean_output = array_ops.zeros_like(var_output)
-
- return mean_output, var_output
-
- @property
- def fisher_factor_inner_shape(self):
- return array_ops.concat(
- [
- array_ops.shape(self._mean)[:-1],
- 2 * array_ops.shape(self._mean)[-1:]
- ],
- axis=0)
-
- @property
- def fisher_factor_inner_static_shape(self):
- shape = self._mean.shape.as_list()
- return tensor_shape.TensorShape(shape[-1:] + [2 * shape[-1]])
-
- def multiply_hessian(self, vector):
- raise NotImplementedError()
-
- def multiply_hessian_factor(self, vector):
- raise NotImplementedError()
-
- def multiply_hessian_factor_transpose(self, vector):
- raise NotImplementedError()
-
- def multiply_hessian_factor_replicated_one_hot(self, index):
- raise NotImplementedError()
-
- @property
- def hessian_factor_inner_shape(self):
- raise NotImplementedError()
-
- @property
- def hessian_factor_inner_static_shape(self):
- raise NotImplementedError()
-
-
-class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
- NaturalParamsNegativeLogProbLoss):
- """Neg log prob loss for a categorical distribution parameterized by logits.
-
-
- Note that the Fisher (for a single case) of a categorical distribution, with
- respect to the natural parameters (i.e. the logits), is given by:
-
- F = diag(p) - p*p^T
-
- where p = softmax(logits). F can be factorized as F = B * B^T where
-
- B = diag(q) - p*q^T
-
- where q is the entry-wise square root of p. This is easy to verify using the
- fact that q^T*q = 1.
- """
-
- def __init__(self, logits, targets=None, seed=None):
- """Instantiates a CategoricalLogitsNegativeLogProbLoss.
-
- Args:
- logits: Tensor of shape [batch_size, output_size]. Parameters for
- underlying distribution.
- targets: None or Tensor of shape [output_size]. Each elements contains an
- index in [0, output_size).
- seed: int or None. Default random seed when sampling.
- """
- self._logits = logits
- self._targets = targets
- super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed)
-
- @property
- def targets(self):
- return self._targets
-
- @property
- def dist(self):
- return categorical.Categorical(logits=self._logits)
-
- @property
- def _probs(self):
- return self.dist.probs
-
- @property
- def _sqrt_probs(self):
- return math_ops.sqrt(self._probs)
-
- @property
- def params(self):
- return self._logits
-
- def multiply_fisher(self, vector):
- probs = self._probs
- return vector * probs - probs * math_ops.reduce_sum(
- vector * probs, axis=-1, keepdims=True)
-
- def multiply_fisher_factor(self, vector):
- probs = self._probs
- sqrt_probs = self._sqrt_probs
- return sqrt_probs * vector - probs * math_ops.reduce_sum(
- sqrt_probs * vector, axis=-1, keepdims=True)
-
- def multiply_fisher_factor_transpose(self, vector):
- probs = self._probs
- sqrt_probs = self._sqrt_probs
- return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum(
- probs * vector, axis=-1, keepdims=True)
-
- def multiply_fisher_factor_replicated_one_hot(self, index):
- assert len(index) == 1, "Length of index was {}".format(len(index))
- probs = self._probs
- sqrt_probs = self._sqrt_probs
- sqrt_probs_slice = array_ops.expand_dims(sqrt_probs[:, index[0]], -1)
- padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1,
- int(sqrt_probs.shape[1]), index[0])
- return padded_slice - probs * sqrt_probs_slice
-
- @property
- def fisher_factor_inner_shape(self):
- return array_ops.shape(self._logits)
-
- @property
- def fisher_factor_inner_static_shape(self):
- return self._logits.shape
-
-
-class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss,
- NaturalParamsNegativeLogProbLoss):
- """Neg log prob loss for multiple Bernoulli distributions param'd by logits.
-
- Represents N independent Bernoulli distributions where N = len(logits). Its
- Fisher Information matrix is given by,
-
- F = diag(p * (1-p))
- p = sigmoid(logits)
-
- As F is diagonal with positive entries, its factor B is,
-
- B = diag(sqrt(p * (1-p)))
- """
-
- def __init__(self, logits, targets=None, seed=None):
- self._logits = logits
- self._targets = targets
- super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed)
-
- @property
- def targets(self):
- return self._targets
-
- @property
- def dist(self):
- return bernoulli.Bernoulli(logits=self._logits)
-
- @property
- def _probs(self):
- return self.dist.probs
-
- @property
- def params(self):
- return self._logits
-
- def multiply_fisher(self, vector):
- return self._probs * (1 - self._probs) * vector
-
- def multiply_fisher_factor(self, vector):
- return math_ops.sqrt(self._probs * (1 - self._probs)) * vector
-
- def multiply_fisher_factor_transpose(self, vector):
- return self.multiply_fisher_factor(vector) # it's symmetric in this case
-
- def multiply_fisher_factor_replicated_one_hot(self, index):
- assert len(index) == 1, "Length of index was {}".format(len(index))
- probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1)
- output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice))
- return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]),
- index[0])
-
- @property
- def fisher_factor_inner_shape(self):
- return array_ops.shape(self._logits)
-
- @property
- def fisher_factor_inner_static_shape(self):
- return self._logits.shape
-
-
-def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position):
- """Inserts slice into a larger tensor of zeros.
-
- Forms a new tensor which is the same shape as slice_to_insert, except that
- the dimension given by 'dim' is expanded to the size given by 'dim_size'.
- 'position' determines the position (index) at which to insert the slice within
- that dimension.
-
- Assumes slice_to_insert.shape[dim] = 1.
-
- Args:
- slice_to_insert: The slice to insert.
- dim: The dimension which to expand with zeros.
- dim_size: The new size of the 'dim' dimension.
- position: The position of 'slice_to_insert' in the new tensor.
-
- Returns:
- The new tensor.
-
- Raises:
- ValueError: If the slice's shape at the given dim is not 1.
- """
- slice_shape = slice_to_insert.shape
- if slice_shape[dim] != 1:
- raise ValueError("Expected slice_to_insert.shape to have {} dim of 1, but "
- "was {}".format(dim, slice_to_insert.shape[dim]))
-
- before = [0] * int(len(slice_shape))
- after = before[:]
- before[dim] = position
- after[dim] = dim_size - position - 1
-
- return array_ops.pad(slice_to_insert, list(zip(before, after)))
-
-
-class OnehotCategoricalLogitsNegativeLogProbLoss(
- CategoricalLogitsNegativeLogProbLoss):
- """Neg log prob loss for a categorical distribution with onehot targets.
-
- Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying
- distribution is OneHotCategorical as opposed to Categorical.
- """
-
- @property
- def dist(self):
- return onehot_categorical.OneHotCategorical(logits=self._logits)
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py
deleted file mode 100644
index 4279cb2792..0000000000
--- a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Loss functions to be used by LayerCollection."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.loss_functions import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- "LossFunction",
- "NegativeLogProbLoss",
- "NaturalParamsNegativeLogProbLoss",
- "DistributionNegativeLogProbLoss",
- "NormalMeanNegativeLogProbLoss",
- "NormalMeanVarianceNegativeLogProbLoss",
- "CategoricalLogitsNegativeLogProbLoss",
- "OnehotCategoricalLogitsNegativeLogProbLoss",
- "MultiBernoulliNegativeLogProbLoss",
- "insert_slice_in_zeros",
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/op_queue.py b/tensorflow/contrib/kfac/python/ops/op_queue.py
deleted file mode 100644
index b6d9d37a31..0000000000
--- a/tensorflow/contrib/kfac/python/ops/op_queue.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Helper for choosing which op to run next in a distributed setting."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops as tf_ops
-
-
-class OpQueue(object):
- """Class for choosing which Op to run next.
-
- Constructs an infinitely repeating sequence of Ops in shuffled order.
-
- In K-FAC, this can be used to distribute inverse update operations among
- workers.
- """
-
- def __init__(self, ops, seed=None):
- """Initializes an OpQueue.
-
- Args:
- ops: list of TensorFlow Ops. Ops to be selected from. All workers must
- initialize with the same set of ops.
- seed: int or None. Random seed used when shuffling order of ops.
- """
- self._ops_by_name = {op.name: op for op in ops}
-
- # Construct a (shuffled) Dataset with Op names.
- op_names = tf_ops.convert_to_tensor(list(sorted(op.name for op in ops)))
- op_names_dataset = (dataset_ops.Dataset.from_tensor_slices(op_names)
- .shuffle(len(ops), seed=seed).repeat())
- self._next_op_name = op_names_dataset.make_one_shot_iterator().get_next()
-
- @property
- def ops(self):
- """Ops this OpQueue can return in next_op()."""
- return self._ops_by_name.values()
-
- def next_op(self, sess):
- """Chooses which op to run next.
-
- Note: This call will make a call to sess.run().
-
- Args:
- sess: tf.Session.
-
- Returns:
- Next Op chosen from 'ops'.
- """
- # In Python 3, type(next_op_name) == bytes. Calling bytes.decode('ascii')
- # returns a str.
- next_op_name = sess.run(self._next_op_name).decode('ascii')
- return self._ops_by_name[next_op_name]
diff --git a/tensorflow/contrib/kfac/python/ops/op_queue_lib.py b/tensorflow/contrib/kfac/python/ops/op_queue_lib.py
deleted file mode 100644
index 09c9a4ab33..0000000000
--- a/tensorflow/contrib/kfac/python/ops/op_queue_lib.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Helper for choosing which op to run next in a distributed setting."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.op_queue import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- 'OpQueue',
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py
deleted file mode 100644
index 38605259b5..0000000000
--- a/tensorflow/contrib/kfac/python/ops/optimizer.py
+++ /dev/null
@@ -1,727 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""The KFAC optimizer."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import warnings
-
-# pylint disable=long-line
-from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
-from tensorflow.contrib.kfac.python.ops import estimator as est
-# pylint enable=long-line
-
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.training import gradient_descent
-
-
-class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
- """The KFAC Optimizer (https://arxiv.org/abs/1503.05671)."""
-
- def __init__(self,
- learning_rate,
- cov_ema_decay,
- damping,
- layer_collection,
- var_list=None,
- momentum=0.9,
- momentum_type="regular",
- norm_constraint=None,
- name="KFAC",
- estimation_mode="gradients",
- colocate_gradients_with_ops=True,
- batch_size=None,
- placement_strategy=None,
- **kwargs):
- """Initializes the KFAC optimizer with the given settings.
-
- Args:
- learning_rate: The base learning rate for the optimizer. Should probably
- be set to 1.0 when using momentum_type = 'qmodel', but can still be
- set lowered if desired (effectively lowering the trust in the
- quadratic model.)
- cov_ema_decay: The decay factor used when calculating the covariance
- estimate moving averages.
- damping: The damping factor used to stabilize training due to errors in
- the local approximation with the Fisher information matrix, and to
- regularize the update direction by making it closer to the gradient.
- If damping is adapted during training then this value is used for
- initializing damping variable.
- (Higher damping means the update looks more like a standard gradient
- update - see Tikhonov regularization.)
- layer_collection: The layer collection object, which holds the fisher
- blocks, Kronecker factors, and losses associated with the
- graph. The layer_collection cannot be modified after KfacOptimizer's
- initialization.
- var_list: Optional list or tuple of variables to train. Defaults to the
- list of variables collected in the graph under the key
- `GraphKeys.TRAINABLE_VARIABLES`.
- momentum: The momentum decay constant to use. Only applies when
- momentum_type is 'regular' or 'adam'. (Default: 0.9)
- momentum_type: The type of momentum to use in this optimizer, one of
- 'regular', 'adam', or 'qmodel'. (Default: 'regular')
- norm_constraint: float or Tensor. If specified, the update is scaled down
- so that its approximate squared Fisher norm v^T F v is at most the
- specified value. May only be used with momentum type 'regular'.
- (Default: None)
- name: The name for this optimizer. (Default: 'KFAC')
- estimation_mode: The type of estimator to use for the Fishers. Can be
- 'gradients', 'empirical', 'curvature_propagation', or 'exact'.
- (Default: 'gradients'). See the doc-string for FisherEstimator for
- more a more detailed description of these options.
- colocate_gradients_with_ops: Whether we should request gradients we
- compute in the estimator be colocated with their respective ops.
- (Default: True)
- batch_size: The size of the mini-batch. Only needed when momentum_type
- == 'qmodel' or when automatic adjustment is used. (Default: None)
- placement_strategy: string, Device placement strategy used when creating
- covariance variables, covariance ops, and inverse ops.
- (Default: `None`)
- **kwargs: Arguments to be passed to specific placement
- strategy mixin. Check `placement.RoundRobinPlacementMixin` for example.
-
- Raises:
- ValueError: If the momentum type is unsupported.
- ValueError: If clipping is used with momentum type other than 'regular'.
- ValueError: If no losses have been registered with layer_collection.
- ValueError: If momentum is non-zero and momentum_type is not 'regular'
- or 'adam'.
- """
- warnings.warn(
- "third_party.tensorflow.contrib.kfac is deprecated."
- "This will be removed on 15-07-2018. Check README for further details.",
- DeprecationWarning)
- # Parameters to be passed to the Fisher estimator:
- self._variables = var_list or tf_variables.trainable_variables
- self._cov_ema_decay = cov_ema_decay
- self._layers = layer_collection
- self._estimation_mode = estimation_mode
- self._colocate_gradients_with_ops = colocate_gradients_with_ops
-
- # The below parameters are required only if damping needs to be adapted.
- # These parameters can be set by calling
- # set_damping_adaptation_params() explicitly.
- self._damping_adaptation_decay = 0.95
- self._damping_adaptation_interval = 5
- # Check section 6.5 KFAC paper. omega(1) = pow(damping decay, interval)
- self._omega = (
- self._damping_adaptation_decay**self._damping_adaptation_interval)
- self._adapt_damping = False
- self._min_damping = 1e-5
- self._prev_train_batch = None
- self._is_chief = False
- self._loss_fn = None
- self._damping_constant = damping
- self._damping = None
- self._rho = None
- self._prev_loss = None
- self._q_model_change = None
- self._update_damping_op = None
-
- momentum_type = momentum_type.lower()
- legal_momentum_types = ["regular", "adam", "qmodel"]
-
- if momentum_type not in legal_momentum_types:
- raise ValueError("Unsupported momentum type {}. Must be one of {}."
- .format(momentum_type, legal_momentum_types))
- if momentum_type != "regular" and norm_constraint is not None:
- raise ValueError("Update clipping is only supported with momentum "
- "type 'regular'.")
- if momentum_type not in ["regular", "adam"] and momentum != 0:
- raise ValueError("Momentum must be unspecified if using a momentum_type "
- "other than 'regular' or 'adam'.")
-
- # Extra parameters of the optimizer
- self._momentum = momentum
- self._momentum_type = momentum_type
- self._norm_constraint = norm_constraint
- self._batch_size = batch_size
- self._placement_strategy = placement_strategy
-
- with variable_scope.variable_scope(name):
- self._fisher_est = est.make_fisher_estimator(
- placement_strategy=placement_strategy,
- variables=self._variables,
- cov_ema_decay=self._cov_ema_decay,
- damping=self.damping,
- layer_collection=self._layers,
- exps=(-1,),
- estimation_mode=self._estimation_mode,
- colocate_gradients_with_ops=self._colocate_gradients_with_ops,
- **kwargs)
-
- super(KfacOptimizer, self).__init__(learning_rate, name=name)
-
- def set_damping_adaptation_params(self,
- is_chief,
- prev_train_batch,
- loss_fn,
- min_damping=1e-5,
- damping_adaptation_decay=0.99,
- damping_adaptation_interval=5):
- """Sets parameters required to adapt damping during training.
-
- When called, enables damping adaptation according to the Levenberg-Marquardt
- style rule described in Section 6.5 of "Optimizing Neural Networks with
- Kronecker-factored Approximate Curvature".
-
- Note that this function creates Tensorflow variables which store a few
- scalars and are accessed by the ops which update the damping (as part
- of the training op returned by the minimize() method).
-
- Args:
- is_chief: `Boolean`, `True` if the worker is chief.
- prev_train_batch: Training data used to minimize loss in the previous
- step. This will be used to evaluate loss by calling
- `loss_fn(prev_train_batch)`.
- loss_fn: `function` that takes as input training data tensor and returns
- a scalar loss.
- min_damping: `float`(Optional), Minimum value the damping parameter
- can take. Default value 1e-5.
- damping_adaptation_decay: `float`(Optional), The `damping` parameter is
- multiplied by the `damping_adaptation_decay` every
- `damping_adaptation_interval` number of iterations. Default value 0.99.
- damping_adaptation_interval: `int`(Optional), Number of steps in between
- updating the `damping` parameter. Default value 5.
-
- Raises:
- ValueError: If `set_damping_adaptation_params` is already called and the
- the `adapt_damping` is `True`.
- """
- if self._adapt_damping:
- raise ValueError("Damping adaptation parameters already set.")
-
- with variable_scope.variable_scope(self.get_name()):
- self._adapt_damping = True
- self._is_chief = is_chief
- self._prev_train_batch = prev_train_batch
- self._loss_fn = loss_fn
- self._damping_adaptation_decay = damping_adaptation_decay
- self._damping_adaptation_interval = damping_adaptation_interval
- self._omega = (
- self._damping_adaptation_decay**self._damping_adaptation_interval)
- self._min_damping = min_damping
-
- self._rho = variable_scope.get_variable(
- "rho", shape=(), dtype=dtypes.float32, trainable=False) # LM ratio.
- self._prev_loss = variable_scope.get_variable(
- "prev_loss", shape=(), dtype=dtypes.float32, trainable=False)
- self._q_model_change = variable_scope.get_variable(
- "q_model_change", shape=(), dtype=dtypes.float32, trainable=False)
- self._damping = variable_scope.get_variable(
- "damping", initializer=self._damping_constant, trainable=False)
-
- @property
- def variables(self):
- return self._fisher_est.variables
-
- @property
- def damping(self):
- if self._damping:
- return self._damping
- else:
- return self._damping_constant
-
- @property
- def damping_adaptation_interval(self):
- return self._damping_adaptation_interval
-
- def make_vars_and_create_op_thunks(self):
- """Make vars and create op thunks.
-
- Returns:
- cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- """
- scope = self.get_name() + "/" + self._fisher_est.name
- return self._fisher_est.make_vars_and_create_op_thunks(scope=scope)
-
- def create_ops_and_vars_thunks(self):
- """Create thunks that make the ops and vars on demand.
-
- This function returns 4 lists of thunks: cov_variable_thunks,
- cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
-
- The length of each list is the number of factors and the i-th element of
- each list corresponds to the i-th factor (given by the "factors" property).
-
- Note that the execution of these thunks must happen in a certain
- partial order. The i-th element of cov_variable_thunks must execute
- before the i-th element of cov_update_thunks (and also the i-th element
- of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
- must execute before the i-th element of inv_update_thunks.
-
- TL;DR (oversimplified): Execute the thunks according to the order that
- they are returned.
-
- Returns:
- cov_variable_thunks: A list of thunks that make the cov variables.
- cov_update_thunks: A list of thunks that make the cov update ops.
- inv_variable_thunks: A list of thunks that make the inv variables.
- inv_update_thunks: A list of thunks that make the inv update ops.
- """
- scope = self.get_name() + "/" + self._fisher_est.name
- return self._fisher_est.create_ops_and_vars_thunks(scope=scope)
-
- def minimize(self, *args, **kwargs):
- # Should this variable scope encompass everything below? Or will the super-
- # class make another copy of the same name scope?
- with variable_scope.variable_scope(self.get_name()):
- kwargs["var_list"] = kwargs.get("var_list") or self.variables
- if set(kwargs["var_list"]) != set(self.variables):
- raise ValueError("var_list doesn't match with set of Fisher-estimating "
- "variables.")
- if self._adapt_damping and self._is_chief:
- global_step = kwargs.get("global_step", None)
- if not global_step:
- raise KeyError("global_step needs to be passed to optimizer.minimize "
- "if damping parameter is adapted.")
- update_damping_op = self._update_damping(self._prev_train_batch,
- global_step)
- with ops.control_dependencies([update_damping_op]):
- loss = args[0]
- loss_assign_op = state_ops.assign(self._prev_loss, loss)
- train_op = super(KfacOptimizer, self).minimize(*args, **kwargs)
- return control_flow_ops.group(loss_assign_op, train_op)
- else:
- return super(KfacOptimizer, self).minimize(*args, **kwargs)
-
- def compute_gradients(self, *args, **kwargs):
- # args[1] could be our var_list
- if len(args) > 1:
- var_list = args[1]
- else:
- kwargs["var_list"] = kwargs.get("var_list") or self.variables
- var_list = kwargs["var_list"]
-
- if set(var_list) != set(self.variables):
- raise ValueError("var_list doesn't match with set of Fisher-estimating "
- "variables.")
- return super(KfacOptimizer, self).compute_gradients(*args, **kwargs)
-
- def apply_gradients(self, grads_and_vars, *args, **kwargs):
- """Applies gradients to variables.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
- *args: Additional arguments for super.apply_gradients.
- **kwargs: Additional keyword arguments for super.apply_gradients.
-
- Returns:
- An `Operation` that applies the specified gradients.
- """
- # In Python 3, grads_and_vars can be a zip() object which can only be
- # iterated over once. By converting it to a list, we ensure that it can be
- # iterated over more than once.
- grads_and_vars = list(grads_and_vars)
-
- # Compute step.
- steps_and_vars = self._compute_update_steps(grads_and_vars)
-
- # Update trainable variables with this step.
- return super(KfacOptimizer, self).apply_gradients(steps_and_vars, *args,
- **kwargs)
-
- def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars):
- """Computes the squared (approximate) Fisher norm of the updates.
-
- This is defined as v^T F v, where F is the approximate Fisher matrix
- as computed by the estimator, and v = F^{-1} g, where g is the gradient.
- This is computed efficiently as v^T g.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
- precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
- Must be the result of calling `self._fisher_est.multiply_inverse`
- on `grads_and_vars`.
-
- Returns:
- Scalar representing the squared norm.
-
- Raises:
- ValueError: if the two list arguments do not contain the same variables,
- in the same order.
- """
- for (_, gvar), (_, pgvar) in zip(grads_and_vars, precon_grads_and_vars):
- if gvar is not pgvar:
- raise ValueError("The variables referenced by the two arguments "
- "must match.")
- terms = [
- math_ops.reduce_sum(grad * pgrad)
- for (grad, _), (pgrad, _) in zip(grads_and_vars, precon_grads_and_vars)
- ]
- return math_ops.reduce_sum(terms)
-
- def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars):
- """Computes the scale factor for the update to satisfy the norm constraint.
-
- Defined as min(1, sqrt(c / r^T F r)), where c is the norm constraint,
- F is the approximate Fisher matrix, and r is the update vector, i.e.
- -alpha * v, where alpha is the learning rate, and v is the preconditioned
- gradient.
-
- This is based on Section 5 of Ba et al., Distributed Second-Order
- Optimization using Kronecker-Factored Approximations. Note that they
- absorb the learning rate alpha (which they denote eta_max) into the formula
- for the coefficient, while in our implementation, the rescaling is done
- before multiplying by alpha. Hence, our formula differs from theirs by a
- factor of alpha.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
- precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
- Must be the result of calling `self._fisher_est.multiply_inverse`
- on `grads_and_vars`.
-
- Returns:
- Scalar representing the coefficient which should be applied to the
- preconditioned gradients to satisfy the norm constraint.
- """
- sq_norm_grad = self._squared_fisher_norm(grads_and_vars,
- precon_grads_and_vars)
- sq_norm_up = sq_norm_grad * self._learning_rate**2
- return math_ops.minimum(1.,
- math_ops.sqrt(self._norm_constraint / sq_norm_up))
-
- def _clip_updates(self, grads_and_vars, precon_grads_and_vars):
- """Rescales the preconditioned gradients to satisfy the norm constraint.
-
- Rescales the preconditioned gradients such that the resulting update r
- (after multiplying by the learning rate) will satisfy the norm constraint.
- This constraint is that r^T F r <= C, where F is the approximate Fisher
- matrix, and C is the norm_constraint attribute. See Section 5 of
- Ba et al., Distributed Second-Order Optimization using Kronecker-Factored
- Approximations.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
- precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
- Must be the result of calling `self._fisher_est.multiply_inverse`
- on `grads_and_vars`.
-
- Returns:
- List of (rescaled preconditioned gradient, variable) pairs.
- """
- coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars)
- return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars]
-
- def _compute_prev_updates(self, variables):
- """Computes previous updates as negative velocities scaled by learning rate.
-
- Args:
- variables: List of variables in the graph that the update will be
- applied to.
-
- Returns:
- List of previous updates applied to the `variables`.
- """
- return list(
- -1 * self._learning_rate * self._zeros_slot(var, "velocity", self._name)
- for var in variables)
-
- def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads,
- variables):
- """Compute optimal update hyperparameters from the quadratic model.
-
- More specifically, if L is the loss we minimize a quadratic approximation
- of L(theta + d) which we denote by qmodel(d) with
- d = alpha*precon_grad + mu*prev_update with respect to alpha and mu, where
-
- qmodel(d) = (1/2) * d^T * B * d + grad^T*d + L(theta) .
-
- Unlike in the KL clipping approach we use the non-approximated quadratic
- model where the curvature matrix C is the true Fisher on the current
- mini-batch (computed without any approximations beyond mini-batch sampling),
- with the usual Tikhonov damping/regularization applied,
-
- C = F + damping * I
-
- See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of
- the formula. See Appendix C for a discussion of the trick of using
- a factorized Fisher matrix to more efficiently compute the required
- vector-matrix-vector products.
-
- Note that the elements of all 4 lists passed to this function must
- be in correspondence with each other.
-
- Args:
- precon_grads: List of preconditioned gradients.
- prev_updates: List of updates computed at the previous iteration.
- grads: List of gradients.
- variables: List of variables in the graph that the update will be
- applied to. (Note that this function doesn't actually apply the
- update.)
-
- Returns:
- (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the
- quadratic model, and
- qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0)
- = qmodel(alpha*precon_grad + mu*prev_update) - L(theta).
- """
-
- cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses,
- variables)
-
- # compute the matrix-vector products with the transposed Fisher factor
- fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads)
- fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates)
- batch_size = math_ops.cast(
- self._batch_size, dtype=fft_precon_grads[0].dtype)
-
- # compute the entries of the 2x2 matrix
- m_11 = (
- _inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size +
- self.damping * _inner_product_list(precon_grads, precon_grads))
-
- m_21 = (
- _inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size +
- self.damping * _inner_product_list(prev_updates, precon_grads))
-
- m_22 = (
- _inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size +
- self.damping * _inner_product_list(prev_updates, prev_updates))
-
- def non_zero_prevupd_case():
- r"""Computes optimal (alpha, mu) given non-zero previous update.
-
- We solve the full 2x2 linear system. See Martens & Grosse (2015),
- Section 7, definition of $\alpha^*$ and $\mu^*$.
-
- Returns:
- (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize
- the quadratic model, and
- qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0).
- """
- m = ops.convert_to_tensor([[m_11, m_21], [m_21, m_22]])
-
- c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)],
- [_inner_product_list(grads, prev_updates)]])
-
- sol = -1. * _two_by_two_solve(m, c)
- alpha = sol[0]
- mu = sol[1]
- qmodel_change = 0.5 * math_ops.reduce_sum(sol * c)
-
- return alpha, mu, qmodel_change
-
- def zero_prevupd_case():
- r"""Computes optimal (alpha, mu) given all-zero previous update.
-
- The linear system reduces to 1x1. See Martens & Grosse (2015),
- Section 6.4, definition of $\alpha^*$.
-
- Returns:
- (alpha, 0.0, qmodel_change), where alpha is chosen to optimize the
- quadratic model, and
- qmodel_change = qmodel(alpha*precon_grad) - qmodel(0)
- """
- m = m_11
- c = _inner_product_list(grads, precon_grads)
-
- alpha = -c / m
- mu = 0.0
- qmodel_change = 0.5 * alpha * c
-
- return alpha, mu, qmodel_change
-
- return control_flow_ops.cond(
- math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case)
-
- def _assign_q_model_change(self, q_model_change):
- """Assigns `q_model_change` to `self._q_model_change` if damping is adapted.
-
- Note only the chief worker does the assignment.
-
- Args:
- q_model_change: Scalar tensor of type `float32`.
-
- Returns:
- If `adapt_damping` is `True` then returns an assign op, Otherwise returns
- a no_op().
- """
- if self._adapt_damping and self._is_chief:
- q_model_assign_op = state_ops.assign(self._q_model_change, q_model_change)
- else:
- q_model_assign_op = control_flow_ops.no_op()
- return q_model_assign_op
-
- def _compute_qmodel_hyperparams_wrapper(self, grads_and_vars,
- precon_grads_and_vars):
- """Wrapper function for `self._compute_qmodel_hyperparams`.
-
- Constructs a list of preconditioned gradients and variables. Also creates a
- op to assign the computed q model change to `self._q_model_change`.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
- precon_grads_and_vars: List of (preconditioned gradients, variable)
- pairs.
-
- Returns:
- (alpha, mu, q_model_assign_op), where alpha and mu are chosen to optimize
- the quadratic model, `q_model_assign_op` assigns the computed q model
- change to `self._q_model_change`.
- """
- precon_grads = list(
- precon_grad for (precon_grad, _) in precon_grads_and_vars)
- grads = list(grad for (grad, _) in grads_and_vars)
- variables = list(var for (_, var) in grads_and_vars)
- prev_updates = self._compute_prev_updates(variables)
- # Compute optimal velocity update parameters according to quadratic model
- alpha, mu, q_model_change = self._compute_qmodel_hyperparams(
- precon_grads, prev_updates, grads, variables)
-
- return alpha, mu, self._assign_q_model_change(q_model_change)
-
- def _compute_update_steps(self, grads_and_vars):
- """Computes the update steps for the variables given the gradients.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
-
- Returns:
- A list of tuple (assign_op ,var) where `assign_op` assigns the update
- steps to `var`.
- """
-
- if self._momentum_type == "regular":
- # Compute "preconditioned" gradient.
- precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
-
- # Apply "KL clipping" if asked for.
- if self._norm_constraint is not None:
- precon_grads_and_vars = self._clip_updates(grads_and_vars,
- precon_grads_and_vars)
-
- # Update the velocity with this and return it as the step.
- if self._adapt_damping and self._is_chief:
- _, _, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
- grads_and_vars, precon_grads_and_vars)
- with ops.control_dependencies([q_model_assign_op]):
- return self._update_velocities(precon_grads_and_vars, self._momentum)
- else:
- return self._update_velocities(precon_grads_and_vars, self._momentum)
- elif self._momentum_type == "adam":
- # Update velocity.
- velocities_and_vars = self._update_velocities(grads_and_vars,
- self._momentum)
- # Return "preconditioned" velocity vector as the step.
- return self._fisher_est.multiply_inverse(velocities_and_vars)
-
- elif self._momentum_type == "qmodel":
- # Compute "preconditioned" gradient.
- precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
-
- # Compute optimal velocity update parameters according to quadratic model
- alpha, mu, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
- grads_and_vars, precon_grads_and_vars)
-
- with ops.control_dependencies([q_model_assign_op]):
- return self._update_velocities(
- precon_grads_and_vars, mu, vec_coeff=-alpha)
-
- def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0):
- """Updates the velocities of the variables with the given vectors.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
- decay: How much to decay the old velocity by. This is often referred to
- as the 'momentum constant'.
- vec_coeff: Coefficient to apply to the vectors before adding them to the
- velocity.
-
- Returns:
- A list of (velocity, var) indicating the new velocity for each var.
- """
-
- def _update_velocity(vec, var):
- velocity = self._zeros_slot(var, "velocity", self._name)
- with ops.colocate_with(velocity):
- # NOTE(mattjj): read/modify/write race condition not suitable for async.
-
- # Compute the new velocity for this variable.
- new_velocity = decay * velocity + vec_coeff * vec
-
- # Save the updated velocity.
- return (array_ops.identity(velocity.assign(new_velocity)), var)
-
- # Go through variable and update its associated part of the velocity vector.
- return [_update_velocity(vec, var) for vec, var in vecs_and_vars]
-
- def _update_damping(self, prev_batch, global_step):
- """Adapts damping parameter. Check KFAC (Section 6.5) for the details.
-
- The damping parameter is updated according to the Levenberg-Marquardt rule
- every `self._damping_adaptation_interval` iterations.
-
- Args:
- prev_batch: Tensor or tuple of tensors which can be passed to
- `self._loss_fn` to evaluate loss.
- global_step: `Variable` which keeps track of number of times the training
- variables have been updated.
- Returns:
- A `tf.cond` op which updates the damping parameter.
- """
- def compute_damping():
- """"Adapts damping parameter based on "reduction ratio".
-
- Reduction ratio captures how closely the quadratic approximation to the
- loss function approximates the actual loss within a trust region. The
- damping update tries to make the damping as small as possible while
- maintaining the property that the quadratic model remains a good local
- approximation to the loss function.
-
- Returns:
- An Op to assign newly computed damping value to `self._damping`.
- """
- prev_batch_loss = self._loss_fn(prev_batch)
- with ops.control_dependencies([prev_batch_loss]):
- rho_assign = self._rho.assign(
- (prev_batch_loss - self._prev_loss) / self._q_model_change)
- with ops.control_dependencies([rho_assign]):
- new_damping = control_flow_ops.case(
- [(self._rho < 0.25, lambda: self.damping / self._omega),
- (self._rho > 0.75, lambda: self.damping * self._omega)],
- lambda: self.damping)
- with ops.control_dependencies([new_damping]):
- new_damping_min = math_ops.maximum(new_damping, self._min_damping)
- return control_flow_ops.group(self._damping.assign(new_damping_min))
-
- return control_flow_ops.cond(
- math_ops.equal(
- math_ops.mod(global_step + 1, self._damping_adaptation_interval),
- 0), compute_damping, control_flow_ops.no_op)
-
-
-def _inner_product_list(list1, list2):
- return math_ops.add_n(
- [math_ops.reduce_sum(elt1 * elt2) for elt1, elt2 in zip(list1, list2)])
-
-
-def _two_by_two_solve(m, c):
- # it might be better just to crank out the exact formula for 2x2 inverses
- return math_ops.matmul(linalg_ops.matrix_inverse(m), c)
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer_lib.py b/tensorflow/contrib/kfac/python/ops/optimizer_lib.py
deleted file mode 100644
index 87d1866e06..0000000000
--- a/tensorflow/contrib/kfac/python/ops/optimizer_lib.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""The KFAC optimizer."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.optimizer import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- "KfacOptimizer",
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py
deleted file mode 100644
index c4454325ae..0000000000
--- a/tensorflow/contrib/kfac/python/ops/placement.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Implements placement strategies for cov and inv ops, cov variables."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import itertools
-
-from tensorflow.python.framework import ops as tf_ops
-
-
-def _make_thunk_on_device(func, device):
- def thunk():
- with tf_ops.device(device):
- return func()
- return thunk
-
-
-class RoundRobinPlacementMixin(object):
- """Implements round robin placement strategy for ops and variables."""
-
- def __init__(self, cov_devices=None, inv_devices=None, **kwargs):
- """Initializes the RoundRobinPlacementMixin class.
-
- Args:
- cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
- inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
- **kwargs: Need something here?
-
- """
- super(RoundRobinPlacementMixin, self).__init__(**kwargs)
- self._cov_devices = cov_devices
- self._inv_devices = inv_devices
-
- def make_vars_and_create_op_thunks(self, scope=None):
- """Make vars and create op thunks w/ a round-robin device placement start.
-
- For each factor, all of that factor's cov variables and their associated
- update ops will be placed on a particular device. A new device is chosen
- for each factor by cycling through list of devices in the
- `self._cov_devices` attribute. If `self._cov_devices` is `Non`e then no
- explicit device placement occurs.
-
- An analogous strategy is followed for inverse update ops, with the list of
- devices being given by the `self._inv_devices` attribute.
-
- Inverse variables on the other hand are not placed on any specific device
- (they will just use the current the device placement context, whatever
- that happens to be). The idea is that the inverse variable belong where
- they will be accessed most often, which is the device that actually applies
- the preconditioner to the gradient. The user will be responsible for setting
- the device context for this.
-
- Args:
- scope: A string or None. If None it will be set to the name of this
- estimator (given by the name property). All variables will be created,
- and all thunks will execute, inside of a variable scope of the given
- name. (Default: None)
-
- Returns:
- cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- """
- # Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`.
- (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw,
- inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope)
-
- if self._cov_devices:
- cov_update_thunks = []
- for cov_variable_thunk, cov_update_thunk, device in zip(
- cov_variable_thunks_raw, cov_update_thunks_raw,
- itertools.cycle(self._cov_devices)):
- with tf_ops.device(device):
- cov_variable_thunk()
- cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk,
- device))
- else:
- for cov_variable_thunk in cov_variable_thunks_raw:
- cov_variable_thunk()
- cov_update_thunks = cov_update_thunks_raw
-
- for inv_variable_thunk in inv_variable_thunks_raw:
- inv_variable_thunk()
-
- if self._inv_devices:
- inv_update_thunks = []
- for inv_update_thunk, device in zip(inv_update_thunks_raw,
- itertools.cycle(self._inv_devices)):
- inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk,
- device))
- else:
- inv_update_thunks = inv_update_thunks_raw
-
- return cov_update_thunks, inv_update_thunks
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
deleted file mode 100644
index 144295f4c7..0000000000
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ /dev/null
@@ -1,709 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utility functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.tpu.python.ops import tpu_ops
-from tensorflow.contrib.tpu.python.tpu import tpu_function
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import variables
-
-# Method used for inverting matrices.
-POSDEF_INV_METHOD = "cholesky"
-POSDEF_EIG_METHOD = "self_adjoint"
-
-
-def set_global_constants(posdef_inv_method=None):
- """Sets various global constants used by the classes in this module."""
- global POSDEF_INV_METHOD
-
- if posdef_inv_method is not None:
- POSDEF_INV_METHOD = posdef_inv_method
-
-
-class SequenceDict(object):
- """A dict convenience wrapper that allows getting/setting with sequences."""
-
- def __init__(self, iterable=None):
- self._dict = dict(iterable or [])
-
- def __getitem__(self, key_or_keys):
- if isinstance(key_or_keys, (tuple, list)):
- return list(map(self.__getitem__, key_or_keys))
- else:
- return self._dict[key_or_keys]
-
- def __setitem__(self, key_or_keys, val_or_vals):
- if isinstance(key_or_keys, (tuple, list)):
- for key, value in zip(key_or_keys, val_or_vals):
- self[key] = value
- else:
- self._dict[key_or_keys] = val_or_vals
-
- def items(self):
- return list(self._dict.items())
-
-
-def tensors_to_column(tensors):
- """Converts a tensor or list of tensors to a column vector.
-
- Args:
- tensors: A tensor or list of tensors.
-
- Returns:
- The tensors reshaped into vectors and stacked on top of each other.
- """
- if isinstance(tensors, (tuple, list)):
- return array_ops.concat(
- tuple(array_ops.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0)
- else:
- return array_ops.reshape(tensors, [-1, 1])
-
-
-def column_to_tensors(tensors_template, colvec):
- """Converts a column vector back to the shape of the given template.
-
- Args:
- tensors_template: A tensor or list of tensors.
- colvec: A 2d column vector with the same shape as the value of
- tensors_to_column(tensors_template).
-
- Returns:
- X, where X is tensor or list of tensors with the properties:
- 1) tensors_to_column(X) = colvec
- 2) X (or its elements) have the same shape as tensors_template (or its
- elements)
- """
- if isinstance(tensors_template, (tuple, list)):
- offset = 0
- tensors = []
- for tensor_template in tensors_template:
- sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32)
- tensor = array_ops.reshape(colvec[offset:(offset + sz)],
- tensor_template.shape)
- tensors.append(tensor)
- offset += sz
-
- tensors = tuple(tensors)
- else:
- tensors = array_ops.reshape(colvec, tensors_template.shape)
-
- return tensors
-
-
-def kronecker_product(mat1, mat2):
- """Computes the Kronecker product two matrices."""
- m1, n1 = mat1.get_shape().as_list()
- mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1])
- m2, n2 = mat2.get_shape().as_list()
- mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
- return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
-
-
-def layer_params_to_mat2d(vector):
- """Converts a vector shaped like layer parameters to a 2D matrix.
-
- In particular, we reshape the weights/filter component of the vector to be
- 2D, flattening all leading (input) dimensions. If there is a bias component,
- we concatenate it to the reshaped weights/filter component.
-
- Args:
- vector: A Tensor or pair of Tensors shaped like layer parameters.
-
- Returns:
- A 2D Tensor with the same coefficients and the same output dimension.
- """
- if isinstance(vector, (tuple, list)):
- w_part, b_part = vector
- w_part_reshaped = array_ops.reshape(w_part,
- [-1, w_part.shape.as_list()[-1]])
- return array_ops.concat(
- (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0)
- elif isinstance(vector, ops.IndexedSlices):
- return vector
- else: # Tensor or Tensor-like.
- return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]])
-
-
-def mat2d_to_layer_params(vector_template, mat2d):
- """Converts a canonical 2D matrix representation back to a vector.
-
- Args:
- vector_template: A Tensor or pair of Tensors shaped like layer parameters.
- mat2d: A 2D Tensor with the same shape as the value of
- layer_params_to_mat2d(vector_template).
-
- Returns:
- A Tensor or pair of Tensors with the same coefficients as mat2d and the same
- shape as vector_template.
- """
- if isinstance(vector_template, (tuple, list)):
- w_part, b_part = mat2d[:-1], mat2d[-1]
- return array_ops.reshape(w_part, vector_template[0].shape), b_part
- elif isinstance(vector_template, ops.IndexedSlices):
- if not isinstance(mat2d, ops.IndexedSlices):
- raise TypeError(
- "If vector_template is an IndexedSlices, so should mat2d.")
- return mat2d
- else:
- return array_ops.reshape(mat2d, vector_template.shape)
-
-
-def posdef_inv(tensor, damping):
- """Computes the inverse of tensor + damping * identity."""
- identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
- damping = math_ops.cast(damping, dtype=tensor.dtype)
- return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping)
-
-
-def posdef_inv_matrix_inverse(tensor, identity, damping):
- """Computes inverse(tensor + damping * identity) directly."""
- return linalg_ops.matrix_inverse(tensor + damping * identity)
-
-
-def posdef_inv_cholesky(tensor, identity, damping):
- """Computes inverse(tensor + damping * identity) with Cholesky."""
- chol = linalg_ops.cholesky(tensor + damping * identity)
- return linalg_ops.cholesky_solve(chol, identity)
-
-
-def posdef_inv_eig(tensor, identity, damping):
- """Computes inverse(tensor + damping * identity) with eigendecomposition."""
- eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(
- tensor + damping * identity)
- return math_ops.matmul(
- eigenvectors / eigenvalues, eigenvectors, transpose_b=True)
-
-
-posdef_inv_functions = {
- "matrix_inverse": posdef_inv_matrix_inverse,
- "cholesky": posdef_inv_cholesky,
- "eig": posdef_inv_eig,
-}
-
-
-def posdef_eig(mat):
- """Computes the eigendecomposition of a positive semidefinite matrix."""
- return posdef_eig_functions[POSDEF_EIG_METHOD](mat)
-
-
-def posdef_eig_svd(mat):
- """Computes the singular values and left singular vectors of a matrix."""
- evals, evecs, _ = linalg_ops.svd(mat)
-
- return evals, evecs
-
-
-def posdef_eig_self_adjoint(mat):
- """Computes eigendecomposition using self_adjoint_eig."""
- evals, evecs = linalg_ops.self_adjoint_eig(mat)
- evals = math_ops.abs(evals) # Should be equivalent to svd approach.
-
- return evals, evecs
-
-
-posdef_eig_functions = {
- "self_adjoint": posdef_eig_self_adjoint,
- "svd": posdef_eig_svd,
-}
-
-
-def cholesky(tensor, damping):
- """Computes the inverse of tensor + damping * identity."""
- identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
- damping = math_ops.cast(damping, dtype=tensor.dtype)
- return linalg_ops.cholesky(tensor + damping * identity)
-
-
-class SubGraph(object):
- """Defines a subgraph given by all the dependencies of a given set of outputs.
- """
-
- def __init__(self, outputs):
- # Set of all ancestor Tensors, Ops to 'outputs'.
- self._members = set()
-
- self._iter_add(outputs)
-
- def _iter_add(self, root):
- """Iteratively adds all of nodes' ancestors using depth first search."""
- stack = [root]
- while stack:
- nodes = stack.pop()
- for node in nodes:
- if node in self._members:
- continue
- self._members.add(node)
-
- if isinstance(node, ops.Tensor):
- stack.append((node.op,))
- elif isinstance(node, ops.Operation):
- stack.append(node.inputs)
-
- def is_member(self, node):
- """Check if 'node' is in this subgraph."""
- return node in self._members
-
- def variable_uses(self, var):
- """Computes number of times a variable is used.
-
- Args:
- var: Variable or ResourceVariable instance.
-
- Returns:
- Number of times a variable is used within this subgraph.
-
- Raises:
- ValueError: If 'var' is not a variable type.
- """
- if isinstance(var, resource_variable_ops.ResourceVariable):
- var = var.handle
- elif isinstance(var, variables.Variable):
- var = var.value()
- else:
- raise ValueError("%s does not appear to be a variable." % str(var))
-
- return len(self._members.intersection(set(var.consumers())))
-
- def filter_list(self, node_list):
- """Filters 'node_list' to nodes in this subgraph."""
- filtered_list = []
- for node in node_list:
- if self.is_member(node):
- filtered_list.append(node)
- return filtered_list
-
-
-def generate_random_signs(shape, dtype=dtypes.float32):
- """Generate a random tensor with {-1, +1} entries."""
- ints = random_ops.random_uniform(shape, maxval=2, dtype=dtypes.int32)
- return 2 * math_ops.cast(ints, dtype=dtype) - 1
-
-
-def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
- """Compute forward-mode gradients."""
- # See b/37888268.
-
- # This version of forward-mode autodiff is based on code by Tim Cooijmans
- # and handles list arguments and certain special cases such as when the
- # ys doesn't depend on one or more of the xs, and when ops.IndexedSlices are
- # generated by the first gradients_impl.gradients call.
-
- us = [array_ops.zeros_like(y) + float("nan") for y in ys]
- dydxs = gradients_impl.gradients(
- ys, xs, grad_ys=us, stop_gradients=stop_gradients)
-
- # Deal with strange types that gradients_impl.gradients returns but can't
- # deal with.
- dydxs = [
- ops.convert_to_tensor(dydx)
- if isinstance(dydx, ops.IndexedSlices) else dydx for dydx in dydxs
- ]
- dydxs = [
- array_ops.zeros_like(x) if dydx is None else dydx
- for x, dydx in zip(xs, dydxs)
- ]
-
- dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs)
-
- return dysdx
-
-
-def on_tpu():
- """Returns True when building a TPU computation."""
- return tpu_function.get_tpu_context().number_of_shards is not None
-
-
-def cross_replica_mean(tensor, name=None):
- """Takes mean value of a Tensor across all TPU cores.
-
- Args:
- tensor: Tensor to be synchronized.
- name: None or string. Name of Op.
-
- Returns:
- Average of Tensor across all TPU cores.
-
- Raises:
- ValueError: If called outside of TPU context.
- """
- with ops.name_scope(name, "cross_replica_mean", [tensor]):
- num_shards = tpu_function.get_tpu_context().number_of_shards
- if num_shards is None:
- raise ValueError(
- "Cannot take cross_replica_mean() outside of TPU Context.")
- if num_shards == 1:
- return tensor
- return tpu_ops.cross_replica_sum(tensor / num_shards)
-
-
-def ensure_sequence(obj):
- """If `obj` isn't a tuple or list, return a tuple containing `obj`."""
- if isinstance(obj, (tuple, list)):
- return obj
- else:
- return (obj,)
-
-
-def batch_execute(global_step, thunks, batch_size, name=None):
- """Executes a subset of ops per global step.
-
- Given a list of thunks, each of which produces a single stateful op,
- ensures that exactly 'batch_size' ops are run per global step. Ops are
- scheduled in a round-robin fashion. For example, with 3 ops
-
- global_step | op0 | op1 | op2
- ------------+-----+-----+-----
- 0 | x | x |
- ------------+-----+-----+-----
- 1 | x | | x
- ------------+-----+-----+-----
- 2 | | x | x
- ------------+-----+-----+-----
- 3 | x | x |
- ------------+-----+-----+-----
- 4 | x | | x
-
- Does not guarantee order of op execution within a single global step.
-
- Args:
- global_step: Tensor indicating time. Determines which ops run.
- thunks: List of thunks. Each thunk encapsulates one op. Return values are
- ignored.
- batch_size: int. Number of ops to execute per global_step.
- name: string or None. Name scope for newly added ops.
-
- Returns:
- List of ops. Exactly 'batch_size' ops are guaranteed to have an effect
- every global step.
- """
-
- def true_fn(thunk):
- """Ensures thunk is executed and returns an Op (not a Tensor)."""
-
- def result():
- with ops.control_dependencies([thunk()]):
- return control_flow_ops.no_op()
-
- return result
-
- def false_fn(_):
- """Executes a no-op."""
-
- def result():
- return control_flow_ops.no_op()
-
- return result
-
- with ops.name_scope(name, "batch_execute"):
- true_fns = [true_fn(thunk) for thunk in thunks]
- false_fns = [false_fn(thunk) for thunk in thunks]
- num_thunks = len(thunks)
- conditions = [
- math_ops.less(
- math_ops.mod(batch_size - 1 + global_step * batch_size - j,
- num_thunks), batch_size) for j in range(num_thunks)
- ]
- result = [
- control_flow_ops.cond(condition, true_fn, false_fn)
- for (condition, true_fn,
- false_fn) in zip(conditions, true_fns, false_fns)
- ]
- return result
-
-
-def extract_convolution_patches(inputs,
- filter_shape,
- padding,
- strides=None,
- dilation_rate=None,
- name=None,
- data_format=None):
- """Extracts inputs to each output coordinate in tf.nn.convolution.
-
- This is a generalization of tf.extract_image_patches() to tf.nn.convolution(),
- where the number of spatial dimensions may be something other than 2.
-
- Assumes,
- - First dimension of inputs is batch_size
- - Convolution filter is applied to all input channels.
-
- Args:
- inputs: Tensor of shape [batch_size, ..spatial_image_shape..,
- ..spatial_filter_shape.., in_channels]. Inputs to tf.nn.convolution().
- filter_shape: List of ints. Shape of filter passed to tf.nn.convolution().
- padding: string. Padding method. One of "VALID", "SAME".
- strides: None or list of ints. Strides along spatial dimensions.
- dilation_rate: None or list of ints. Dilation along spatial dimensions.
- name: None or str. Name of Op.
- data_format: None or str. Format of data.
-
- Returns:
- Tensor of shape [batch_size, ..spatial_image_shape..,
- ..spatial_filter_shape.., in_channels]
-
- Raises:
- ValueError: If data_format does not put channel last.
- ValueError: If inputs and filter disagree on in_channels.
- """
- if not is_data_format_channel_last(data_format):
- raise ValueError("Channel must be last dimension.")
- with ops.name_scope(name, "extract_convolution_patches",
- [inputs, filter_shape, padding, strides, dilation_rate]):
- batch_size = inputs.shape.as_list()[0]
- in_channels = inputs.shape.as_list()[-1]
-
- # filter_shape = spatial_filter_shape + [in_channels, out_channels]
- spatial_filter_shape = filter_shape[:-2]
- if in_channels != filter_shape[-2]:
- raise ValueError("inputs and filter_shape must agree on in_channels.")
-
- # Map each input feature to a location in the output.
- out_channels = np.prod(spatial_filter_shape) * in_channels
- filters = linalg_ops.eye(out_channels)
- filters = array_ops.reshape(
- filters,
- list(spatial_filter_shape) + [in_channels, out_channels])
-
- result = nn_ops.convolution(
- inputs,
- filters,
- padding=padding,
- strides=strides,
- dilation_rate=dilation_rate)
- spatial_output_shape = result.shape.as_list()[1:-1]
- result = array_ops.reshape(result,
- [batch_size or -1] + spatial_output_shape +
- list(spatial_filter_shape) + [in_channels])
-
- return result
-
-
-def extract_pointwise_conv2d_patches(inputs,
- filter_shape,
- name=None,
- data_format=None):
- """Extract patches for a 1x1 conv2d.
-
- Args:
- inputs: 4-D Tensor of shape [batch_size, height, width, in_channels].
- filter_shape: List of 4 ints. Shape of filter to apply with conv2d()
- name: None or str. Name for Op.
- data_format: None or str. Format for data. See 'data_format' in
- tf.nn.conv2d() for details.
-
- Returns:
- Tensor of shape [batch_size, ..spatial_input_shape..,
- ..spatial_filter_shape.., in_channels]
-
- Raises:
- ValueError: if inputs is not 4-D.
- ValueError: if filter_shape is not [1, 1, ?, ?]
- ValueError: if data_format is not channels-last.
- """
- if inputs.shape.ndims != 4:
- raise ValueError("inputs must have 4 dims.")
- if len(filter_shape) != 4:
- raise ValueError("filter_shape must have 4 dims.")
- if filter_shape[0] != 1 or filter_shape[1] != 1:
- raise ValueError("filter_shape must have shape 1 along spatial dimensions.")
- if not is_data_format_channel_last(data_format):
- raise ValueError("data_format must be channels last.")
- with ops.name_scope(name, "extract_pointwise_conv2d_patches",
- [inputs, filter_shape]):
- ksizes = [1, 1, 1, 1] # Spatial shape is 1x1.
- strides = [1, 1, 1, 1] # Operate on all pixels.
- rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1.
- padding = "VALID" # Doesn't matter.
- result = array_ops.extract_image_patches(inputs, ksizes, strides, rates,
- padding)
-
- batch_size, input_height, input_width, in_channels = inputs.shape.as_list()
- filter_height, filter_width, in_channels, _ = filter_shape
- return array_ops.reshape(result, [
- batch_size, input_height, input_width, filter_height, filter_width,
- in_channels
- ])
-
-
-def is_data_format_channel_last(data_format):
- """True if data_format puts channel last."""
- if data_format is None:
- return True
- return data_format.endswith("C")
-
-
-def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name
- """Computes matmul(A, B) where A is sparse, B is dense.
-
- Args:
- A: tf.IndexedSlices with dense shape [m, n].
- B: tf.Tensor with shape [n, k].
- name: str. Name of op.
- transpose_a: Bool. If true we transpose A before multiplying it by B.
- (Default: False)
- transpose_b: Bool. If true we transpose B before multiplying it by A.
- (Default: False)
-
- Returns:
- tf.IndexedSlices resulting from matmul(A, B).
-
- Raises:
- ValueError: If A doesn't represent a matrix.
- ValueError: If B is not rank-2.
- """
- with ops.name_scope(name, "matmul_sparse_dense", [A, B]):
- if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2:
- raise ValueError("A must represent a matrix. Found: %s." % A)
- if B.shape.ndims != 2:
- raise ValueError("B must be a matrix.")
- new_values = math_ops.matmul(
- A.values, B, transpose_a=transpose_a, transpose_b=transpose_b)
- return ops.IndexedSlices(
- new_values,
- A.indices,
- dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]]))
-
-
-def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name
- """Computes matmul(A, B) where A is a diagonal matrix, B is sparse.
-
- Args:
- A_diag: diagonal entries of matrix A of shape [m, m].
- B: tf.IndexedSlices. Represents matrix of shape [m, n].
- name: str. Name of op.
-
- Returns:
- tf.IndexedSlices resulting from matmul(A, B).
-
- Raises:
- ValueError: If A_diag is not rank-1.
- ValueError: If B doesn't represent a matrix.
- """
- with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]):
- A_diag = ops.convert_to_tensor(A_diag)
- if A_diag.shape.ndims != 1:
- raise ValueError("A_diag must be a rank-1 Tensor.")
- if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2:
- raise ValueError("B must represent a matrix. Found: %s." % B)
- a = array_ops.gather(A_diag, B.indices)
- a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1))
- return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape)
-
-
-class PartitionedTensor(object):
- """A Tensor partitioned across its 0-th dimension."""
-
- def __init__(self, tensors):
- """Initializes PartitionedTensor.
-
- Args:
- tensors: List of Tensors. All Tensors must agree on shape (excepting
- batch dimension) and dtype.
-
- Raises:
- ValueError: If 'tensors' has length zero.
- ValueError: if contents of 'tensors' don't agree on shape or dtype.
- """
- if not tensors:
- raise ValueError("tensors must be a list of 1+ Tensors.")
-
- dtype = tensors[0].dtype
- if not all(tensor.dtype == dtype for tensor in tensors):
- raise ValueError("all tensors must have dtype = %s." % dtype)
-
- shape = tensors[0].shape[1:]
- if not all(tensor.shape[1:] == shape for tensor in tensors):
- raise ValueError("All tensors must have shape = %s (excluding batch "
- "dimension)." % shape)
-
- self.tensors = tensors
- self._concats = {} # {device: Tensor}
-
- @property
- def shape(self):
- feature_shape = self.tensors[0].shape[1:]
- batch_size = sum([tensor.shape[0] for tensor in self.tensors],
- tensor_shape.Dimension(0))
- return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape)
-
- def get_shape(self):
- return self.shape
-
- @property
- def dtype(self):
- return self.tensors[0].dtype
-
- def __str__(self):
- return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % (
- self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list()))
-
- def __hash__(self):
- return hash(tuple(self.tensors))
-
- def __eq__(self, other):
- if not isinstance(other, PartitionedTensor):
- return False
- return self.tensors == other.tensors
-
- def __ne__(self, other):
- return not self == other # pylint: disable=g-comparison-negation
-
- def __getitem__(self, key):
- return self.as_tensor()[key]
-
- def as_tensor(self, dtype=None, name=None, as_ref=False):
- with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors):
- assert not as_ref
- assert dtype in [None, self.dtype]
- result = array_ops.concat(self.tensors, axis=0)
-
- # Cache 'result' if we haven't already cached a value for this device.
- if result.device not in self._concats:
- self._concats[result.device] = result
- return self._concats[result.device]
-
- @property
- def device(self):
- # PartitionedTensors in general do not live on a single device. If the
- # device cannot be determined unambiguously this property will return None.
- device = self.tensors[0].device
- if all(tensor.device == device for tensor in self.tensors):
- return device
- return None
-
-
-ops.register_tensor_conversion_function(
- PartitionedTensor,
- lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref))
-
-
-# TODO(b/69623235): Add a function for finding tensors that share gradients
-# to eliminate redundant fisher factor computations.
diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py
deleted file mode 100644
index 330d222dbf..0000000000
--- a/tensorflow/contrib/kfac/python/ops/utils_lib.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utility functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.utils import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- "set_global_constants",
- "SequenceDict",
- "tensors_to_column",
- "column_to_tensors",
- "kronecker_product",
- "layer_params_to_mat2d",
- "mat2d_to_layer_params",
- "posdef_inv",
- "posdef_inv_matrix_inverse",
- "posdef_inv_cholesky",
- "posdef_inv_funcs",
- "SubGraph",
- "generate_random_signs",
- "fwd_gradients",
- "ensure_sequence",
- "batch_execute",
- "extract_convolution_patches",
- "extract_pointwise_conv2d_patches",
- "is_data_format_channel_last",
- "matmul_sparse_dense",
- "matmul_diag_sparse",
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)