author | Vikram Tankasali <tvikram@google.com> | 2018-08-22 12:51:59 -0700
---|---|---
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-22 12:56:27 -0700
commit | c73964210ced86791c9231768315fa4652abc9ba (patch) |
tree | d3648366bd10b32e9b15eaa7d0dbcc545f95ae39 /tensorflow/contrib/kfac |
parent | c85e0a9829258133c84e863b5c35be9e3f9aa280 (diff) |
BEGIN_PUBLIC
Delete tf.contrib.kfac. K-FAC in TensorFlow is now its own separate package.
END_PUBLIC
RELNOTES: n/a
Automated rollback of commit 938b9a40787028c58fb548fa6ada8c0dd8180f35
PiperOrigin-RevId: 209813506
Diffstat (limited to 'tensorflow/contrib/kfac')
46 files changed, 1 insertion, 14429 deletions
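Editor's note for readers landing on this commit: the replacement code lives at https://github.com/tensorflow/kfac. The sketch below is an editorial illustration, not part of this commit — it assumes the standalone package installs as `pip install kfac` and keeps the contrib-era `LayerCollection`/`KfacOptimizer` names and signatures (check that repository's own README before relying on it). Under that assumption, the deleted `tf.contrib.kfac` workflow carries over with only the import root changed:

```python
# Hedged migration sketch -- assumes the standalone package (pip install kfac)
# mirrors the deleted tf.contrib.kfac API; see github.com/tensorflow/kfac.
import kfac
import tensorflow as tf

# A minimal softmax classifier, mirroring the deleted README's example.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.int64, [None])
w = tf.get_variable("w", [784, 10])
b = tf.get_variable("b", [10])
logits = tf.matmul(x, w) + b
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits))

# Same layer/loss registration calls as with tf.contrib.kfac.
layer_collection = kfac.LayerCollection()
layer_collection.register_fully_connected((w, b), x, logits)
layer_collection.register_categorical_predictive_distribution(logits)

optimizer = kfac.KfacOptimizer(
    learning_rate=0.0001,
    cov_ema_decay=0.95,
    damping=0.001,
    layer_collection=layer_collection,
    momentum=0.9)
train_op = optimizer.minimize(loss)
```

The registration pattern is exactly the one documented in the deleted README reproduced in the diff below.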
diff --git a/tensorflow/contrib/kfac/BUILD b/tensorflow/contrib/kfac/BUILD
deleted file mode 100644
index b719046b37..0000000000
--- a/tensorflow/contrib/kfac/BUILD
+++ /dev/null
@@ -1,26 +0,0 @@
-# Description:
-#   Contains KfacOptimizer, an implementation of the K-FAC optimization
-#   algorithm in TensorFlow.
-package(default_visibility = ["//visibility:public"])
-
-licenses(["notice"])  # Apache 2.0
-
-exports_files(["LICENSE"])
-
-py_library(
-    name = "kfac",
-    srcs = ["__init__.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        "//tensorflow/contrib/kfac/python/ops:curvature_matrix_vector_products_lib",
-        "//tensorflow/contrib/kfac/python/ops:fisher_blocks_lib",
-        "//tensorflow/contrib/kfac/python/ops:fisher_estimator_lib",
-        "//tensorflow/contrib/kfac/python/ops:fisher_factors_lib",
-        "//tensorflow/contrib/kfac/python/ops:kfac_optimizer_lib",
-        "//tensorflow/contrib/kfac/python/ops:layer_collection_lib",
-        "//tensorflow/contrib/kfac/python/ops:loss_functions_lib",
-        "//tensorflow/contrib/kfac/python/ops:op_queue_lib",
-        "//tensorflow/contrib/kfac/python/ops:utils_lib",
-        "//tensorflow/python:util",
-    ],
-)
diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md
index 102626925d..42b91d0313 100644
--- a/tensorflow/contrib/kfac/README.md
+++ b/tensorflow/contrib/kfac/README.md
@@ -1,94 +1,3 @@
 # K-FAC: Kronecker-Factored Approximate Curvature
-# <font color="red", size=10><u>WARNING: </u></font>
-# ==third_party/tensorflow/contrib/kfac is deprecated. This will be==
-# ==removed on 15-07-2018. <!-- STY:begin_strip_and_replace -->Please import third_party/tensorflow_kfac.==
-# ==<!-- STY:end_strip_and_replace Please check https://github.com/tensorflow/kfac. -->==
-
-**K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an
-approximate second-order optimization method, in TensorFlow. When applied to
-feedforward and convolutional neural networks, K-FAC can converge `>3.5x`
-faster in `>14x` fewer iterations than SGD with Momentum.
-
-[kfac-paper]: https://arxiv.org/abs/1503.05671
-
-## What is K-FAC?
-
-K-FAC, short for "Kronecker-factored Approximate Curvature", is an
-approximation to the [Natural Gradient][natural_gradient] algorithm designed
-specifically for neural networks. It maintains a block-diagonal approximation
-to the [Fisher Information matrix][fisher_information], whose inverse
-preconditions the gradient.
-
-K-FAC can be used in place of SGD, Adam, and other `Optimizer` implementations.
-Experimentally, K-FAC converges `>3.5x` faster than well-tuned SGD.
-
-Unlike most optimizers, K-FAC exploits structure in the model itself (e.g.
-"What are the weights for layer i?"). As such, you must add some additional
-code while constructing your model to use K-FAC.
-
-[natural_gradient]: http://www.mitpressjournals.org/doi/abs/10.1162/089976698300017746
-[fisher_information]: https://en.wikipedia.org/wiki/Fisher_information#Matrix_form
-
-## Why should I use K-FAC?
-
-K-FAC can take advantage of the curvature of the optimization problem,
-resulting in **faster training**. For an 8-layer Autoencoder, K-FAC converges
-to the same loss as SGD with Momentum in 3.8x fewer seconds and 14.7x fewer
-updates. See how training loss changes as a function of number of epochs,
-steps, and seconds:
-
-![autoencoder](g3doc/autoencoder.png)
-
-## Is K-FAC for me?
-
-If you have a feedforward or convolutional model for classification that is
-converging too slowly, K-FAC is for you.
-K-FAC can be used in your model if:
-
-* Your model defines a posterior distribution.
-* Your model uses only fully-connected or convolutional layers (residual
-  connections OK).
-* You are training on CPU or GPU.
-* You can modify model code to register layers with K-FAC.
-
-## How do I use K-FAC?
-
-Using K-FAC requires three steps:
-
-1.  Registering layer inputs, weights, and pre-activations with a
-    `LayerCollection`.
-1.  Minimizing the loss with a `KfacOptimizer`.
-1.  Keeping K-FAC's preconditioner updated.
-
-```python
-# Build model.
-w = tf.get_variable("w", ...)
-b = tf.get_variable("b", ...)
-logits = tf.matmul(x, w) + b
-loss = tf.reduce_mean(
-    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
-
-# Register layers.
-layer_collection = LayerCollection()
-layer_collection.register_fully_connected((w, b), x, logits)
-layer_collection.register_categorical_predictive_distribution(logits)
-
-# Construct training ops.
-optimizer = KfacOptimizer(..., layer_collection=layer_collection)
-train_op = optimizer.minimize(loss)
-
-# Minimize loss.
-with tf.Session() as sess:
-  ...
-  sess.run([train_op, optimizer.cov_update_op, optimizer.inv_update_op])
-```
-
-See [`examples/`](https://www.tensorflow.org/code/tensorflow/contrib/kfac/examples/) for runnable, end-to-end illustrations.
-
-## Authors
-
-- Alok Aggarwal
-- Daniel Duckworth
-- James Martens
-- Matthew Johnson
-- Olga Wichrowska
-- Roger Grosse
+## KFAC moved to third_party/tensorflow_kfac.
diff --git a/tensorflow/contrib/kfac/__init__.py b/tensorflow/contrib/kfac/__init__.py
deleted file mode 100644
index 1ea354e6cd..0000000000
--- a/tensorflow/contrib/kfac/__init__.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Kronecker-factored Approximate Curvature Optimizer."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long
-from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products_lib as curvature_matrix_vector_products
-from tensorflow.contrib.kfac.python.ops import estimator_lib as estimator
-from tensorflow.contrib.kfac.python.ops import fisher_blocks_lib as fisher_blocks
-from tensorflow.contrib.kfac.python.ops import fisher_factors_lib as fisher_factors
-from tensorflow.contrib.kfac.python.ops import layer_collection_lib as layer_collection
-from tensorflow.contrib.kfac.python.ops import loss_functions_lib as loss_functions
-from tensorflow.contrib.kfac.python.ops import op_queue_lib as op_queue
-from tensorflow.contrib.kfac.python.ops import optimizer_lib as optimizer
-from tensorflow.contrib.kfac.python.ops import utils_lib as utils
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long
-
-_allowed_symbols = [
-    "curvature_matrix_vector_products",
-    "estimator",
-    "fisher_blocks",
-    "fisher_factors",
-    "layer_collection",
-    "loss_functions",
-    "op_queue",
-    "optimizer",
-    "utils",
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD
deleted file mode 100644
index 8186fa1c62..0000000000
--- a/tensorflow/contrib/kfac/examples/BUILD
+++ /dev/null
@@ -1,80 +0,0 @@
-package(default_visibility = [
-    "//learning/brain/contrib/kfac/examples:__subpackages__",
-    "//tensorflow/contrib/kfac/examples:__subpackages__",
-])
-
-licenses(["notice"])  # Apache 2.0
-
-exports_files(["LICENSE"])
-
-py_binary(
-    name = "mlp_mnist_main",
-    srcs = ["mlp_mnist_main.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":mlp",
-        "//tensorflow:tensorflow_py",
-    ],
-)
-
-py_library(
-    name = "mlp",
-    srcs = ["mlp.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":mnist",
-        "//tensorflow:tensorflow_py",
-    ],
-)
-
-py_binary(
-    name = "convnet_mnist_single_main",
-    srcs = ["convnet_mnist_single_main.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":convnet",
-        "//tensorflow:tensorflow_py",
-    ],
-)
-
-py_binary(
-    name = "convnet_mnist_multi_tower_main",
-    srcs = ["convnet_mnist_multi_tower_main.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":convnet",
-        "//tensorflow:tensorflow_py",
-    ],
-)
-
-py_binary(
-    name = "convnet_mnist_distributed_main",
-    srcs = ["convnet_mnist_distributed_main.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":convnet",
-        "//tensorflow:tensorflow_py",
-    ],
-)
-
-py_library(
-    name = "convnet",
-    srcs = ["convnet.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":mlp",
-        ":mnist",
-        "//tensorflow:tensorflow_py",
-        "//third_party/py/numpy",
-    ],
-)
-
-py_library(
-    name = "mnist",
-    srcs = ["mnist.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        "//tensorflow:tensorflow_py",
-        "//third_party/py/numpy",
-    ],
-)
diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py
deleted file mode 100644
index 44e01e1aeb..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet.py
+++ /dev/null
@@ -1,667 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-This library fits a 5-layer ConvNet on MNIST using K-FAC. The model has the
-following structure:
-
-- Conv Layer: 5x5 kernel, 16 output channels.
-- Max Pool: 3x3 kernel, stride 2.
-- Conv Layer: 5x5 kernel, 16 output channels.
-- Max Pool: 3x3 kernel, stride 2.
-- Linear: 10 output dims.
-
-After 3k~6k steps, this should reach perfect accuracy on the training set.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mlp
-from tensorflow.contrib.kfac.examples import mnist
-from tensorflow.contrib.kfac.python.ops import optimizer as opt
-
-
-lc = tf.contrib.kfac.layer_collection
-oq = tf.contrib.kfac.op_queue
-opt = tf.contrib.kfac.optimizer
-
-__all__ = [
-    "conv_layer",
-    "max_pool_layer",
-    "linear_layer",
-    "build_model",
-    "minimize_loss_single_machine",
-    "distributed_grads_only_and_ops_chief_worker",
-    "distributed_grads_and_ops_dedicated_workers",
-    "train_mnist_single_machine",
-    "train_mnist_distributed_sync_replicas",
-    "train_mnist_multitower",
-]
-
-
-# Inverse update ops will be run every _INVERT_EVERY iterations.
-_INVERT_EVERY = 10
-
-
-def conv_layer(layer_id, inputs, kernel_size, out_channels):
-  """Builds a convolutional layer with ReLU non-linearity.
-
-  Args:
-    layer_id: int. Integer ID for this layer's variables.
-    inputs: Tensor of shape [num_examples, width, height, in_channels]. Each
-      row corresponds to a single example.
-    kernel_size: int. Width and height of the convolution kernel. The kernel
-      is assumed to be square.
-    out_channels: int. Number of output features per pixel.
-
-  Returns:
-    preactivations: Tensor of shape [num_examples, width, height,
-      out_channels]. Values of the layer immediately before the activation
-      function.
-    activations: Tensor of shape [num_examples, width, height, out_channels].
-      Values of the layer immediately after the activation function.
-    params: Tuple of (kernel, bias), parameters for this layer.
-  """
-  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
-  layer = tf.layers.Conv2D(
-      out_channels,
-      kernel_size=[kernel_size, kernel_size],
-      kernel_initializer=tf.random_normal_initializer(stddev=0.01),
-      padding="SAME",
-      name="conv_%d" % layer_id)
-  preactivations = layer(inputs)
-  activations = tf.nn.relu(preactivations)
-
-  # layer.weights is a list. This converts it to a (hashable) tuple.
-  return preactivations, activations, (layer.kernel, layer.bias)
-
-
-def max_pool_layer(layer_id, inputs, kernel_size, stride):
-  """Build a max-pooling layer.
-
-  Args:
-    layer_id: int. Integer ID for this layer's variables.
-    inputs: Tensor of shape [num_examples, width, height, in_channels]. Each
-      row corresponds to a single example.
-    kernel_size: int. Width and height to pool over per input channel. The
-      kernel is assumed to be square.
-    stride: int. Step size between pooling operations.
-
-  Returns:
-    Tensor of shape [num_examples, width/stride, height/stride, out_channels].
-    Result of applying max pooling to 'inputs'.
-  """
-  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
-  with tf.variable_scope("pool_%d" % layer_id):
-    return tf.nn.max_pool(
-        inputs, [1, kernel_size, kernel_size, 1], [1, stride, stride, 1],
-        padding="SAME",
-        name="pool")
-
-
-def linear_layer(layer_id, inputs, output_size):
-  """Builds the final linear layer for an MNIST classification problem.
-
-  Args:
-    layer_id: int. Integer ID for this layer's variables.
-    inputs: Tensor of shape [num_examples, width, height, in_channels]. Each
-      row corresponds to a single example.
-    output_size: int. Number of output dims per example.
-
-  Returns:
-    activations: Tensor of shape [num_examples, output_size]. Values of the
-      layer immediately after the activation function.
-    params: Tuple of (weights, bias), parameters for this layer.
-  """
-  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
-  pre, _, params = mlp.fc_layer(layer_id, inputs, output_size)
-  return pre, params
-
-
-def build_model(examples, labels, num_labels, layer_collection):
-  """Builds a ConvNet classification model.
-
-  Args:
-    examples: Tensor of shape [num_examples, num_features]. Represents inputs
-      of model.
-    labels: Tensor of shape [num_examples]. Contains integer IDs to be
-      predicted by softmax for each example.
-    num_labels: int. Number of distinct values 'labels' can take on.
-    layer_collection: LayerCollection instance. Layers will be registered here.
-
-  Returns:
-    loss: 0-D Tensor representing loss to be minimized.
-    accuracy: 0-D Tensor representing model's accuracy.
-  """
-  # Build a ConvNet. For each layer with parameters, we'll keep track of the
-  # preactivations, activations, weights, and bias.
-  tf.logging.info("Building model.")
-  pre0, act0, params0 = conv_layer(
-      layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
-  act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
-  pre2, act2, params2 = conv_layer(
-      layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
-  act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
-  flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
-  logits, params4 = linear_layer(
-      layer_id=4, inputs=flat_act3, output_size=num_labels)
-  loss = tf.reduce_mean(
-      tf.nn.sparse_softmax_cross_entropy_with_logits(
-          labels=labels, logits=logits))
-  accuracy = tf.reduce_mean(
-      tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
-
-  with tf.device("/cpu:0"):
-    tf.summary.scalar("loss", loss)
-    tf.summary.scalar("accuracy", accuracy)
-
-  # Register parameters. K-FAC needs to know about the inputs, outputs, and
-  # parameters of each conv/fully connected layer and the logits powering the
-  # posterior probability over classes.
-  tf.logging.info("Building LayerCollection.")
-  layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
-                                   pre0)
-  layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
-  layer_collection.register_fully_connected(params4, flat_act3, logits)
-  layer_collection.register_categorical_predictive_distribution(
-      logits, name="logits")
-
-  return loss, accuracy
-
-
-def minimize_loss_single_machine(loss,
-                                 accuracy,
-                                 layer_collection,
-                                 device="/gpu:0",
-                                 session_config=None):
-  """Minimize loss with K-FAC on a single machine.
-
-  A single Session is responsible for running all of K-FAC's ops. The
-  covariance and inverse update ops are placed on `device`. All model
-  variables are on CPU.
-
-  Args:
-    loss: 0-D Tensor. Loss to be minimized.
-    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
-    layer_collection: LayerCollection instance describing model architecture.
-      Used by K-FAC to construct preconditioner.
-    device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse
-      update ops are run on this device.
-    session_config: None or tf.ConfigProto. Configuration for tf.Session().
-
-  Returns:
-    final value for 'accuracy'.
-  """
-  # Train with K-FAC.
-  g_step = tf.train.get_or_create_global_step()
-  optimizer = opt.KfacOptimizer(
-      learning_rate=0.0001,
-      cov_ema_decay=0.95,
-      damping=0.001,
-      layer_collection=layer_collection,
-      placement_strategy="round_robin",
-      cov_devices=[device],
-      inv_devices=[device],
-      momentum=0.9)
-  (cov_update_thunks,
-   inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
-  def make_update_op(update_thunks):
-    update_ops = [thunk() for thunk in update_thunks]
-    return tf.group(*update_ops)
-
-  cov_update_op = make_update_op(cov_update_thunks)
-  with tf.control_dependencies([cov_update_op]):
-    inverse_op = tf.cond(
-        tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
-        lambda: make_update_op(inv_update_thunks), tf.no_op)
-  with tf.control_dependencies([inverse_op]):
-    with tf.device(device):
-      train_op = optimizer.minimize(loss, global_step=g_step)
-
-  tf.logging.info("Starting training.")
-  with tf.train.MonitoredTrainingSession(config=session_config) as sess:
-    while not sess.should_stop():
-      global_step_, loss_, accuracy_, _ = sess.run(
-          [g_step, loss, accuracy, train_op])
-
-      if global_step_ % _INVERT_EVERY == 0:
-        tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
-                        global_step_, loss_, accuracy_)
-
-  return accuracy_
-
-
-def _is_gradient_task(task_id, num_tasks):
-  """Returns True if this task should update the weights."""
-  if num_tasks < 3:
-    return True
-  return 0 <= task_id < 0.6 * num_tasks
-
-
-def _is_cov_update_task(task_id, num_tasks):
-  """Returns True if this task should update K-FAC's covariance matrices."""
-  if num_tasks < 3:
-    return False
-  return 0.6 * num_tasks <= task_id < num_tasks - 1
-
-
-def _is_inv_update_task(task_id, num_tasks):
-  """Returns True if this task should update K-FAC's preconditioner."""
-  if num_tasks < 3:
-    return False
-  return task_id == num_tasks - 1
-
-
-def _num_gradient_tasks(num_tasks):
-  """Number of tasks that will update weights."""
-  if num_tasks < 3:
-    return num_tasks
-  return int(np.ceil(0.6 * num_tasks))
-
-
-def _make_distributed_train_op(
-    task_id,
-    num_worker_tasks,
-    num_ps_tasks,
-    layer_collection
-):
-  """Creates optimizer and distributed training op.
-
-  Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes
-  the train op.
-
-  Args:
-    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
-    num_worker_tasks: int. Number of workers in this distributed training
-      setup.
-    num_ps_tasks: int. Number of parameter servers holding variables. If 0,
-      parameter servers are not used.
-    layer_collection: LayerCollection instance describing model architecture.
-      Used by K-FAC to construct preconditioner.
-
-  Returns:
-    sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC
-      optimizer.
-    optimizer: Instance of `opt.KfacOptimizer`.
-    global_step: `tensor`, Global step.
-  """
-  tf.logging.info("Task id: %d", task_id)
-  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
-    global_step = tf.train.get_or_create_global_step()
-    optimizer = opt.KfacOptimizer(
-        learning_rate=0.0001,
-        cov_ema_decay=0.95,
-        damping=0.001,
-        layer_collection=layer_collection,
-        momentum=0.9)
-    sync_optimizer = tf.train.SyncReplicasOptimizer(
-        opt=optimizer,
-        replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks),
-        total_num_replicas=num_worker_tasks)
-    return sync_optimizer, optimizer, global_step
-
-
-def distributed_grads_only_and_ops_chief_worker(
-    task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
-    loss, accuracy, layer_collection, invert_every=10):
-  """Minimize loss with a synchronous implementation of K-FAC.
-
-  All workers perform gradient computation. Chief worker applies gradient
-  after averaging the gradients obtained from all the workers. All workers
-  block execution until the update is applied. Chief worker runs covariance
-  and inverse update ops. Covariance and inverse matrices are placed on
-  parameter servers in a round robin manner. For further details on
-  synchronous distributed optimization check `tf.train.SyncReplicasOptimizer`.
-
-  Args:
-    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
-    is_chief: `boolean`, `True` if the worker is chief worker.
-    num_worker_tasks: int. Number of workers in this distributed training
-      setup.
-    num_ps_tasks: int. Number of parameter servers holding variables. If 0,
-      parameter servers are not used.
-    master: string. IP and port of TensorFlow runtime process. Set to empty
-      string to run locally.
-    checkpoint_dir: string or None. Path to store checkpoints under.
-    loss: 0-D Tensor. Loss to be minimized.
-    accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to
-      run with each step.
-    layer_collection: LayerCollection instance describing model architecture.
-      Used by K-FAC to construct preconditioner.
-    invert_every: `int`, Number of steps between inverse updates.
-
-  Returns:
-    final value for 'accuracy'.
-
-  Raises:
-    ValueError: if task_id >= num_worker_tasks.
-  """
-
-  sync_optimizer, optimizer, global_step = _make_distributed_train_op(
-      task_id, num_worker_tasks, num_ps_tasks, layer_collection)
-  (cov_update_thunks,
-   inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
-  tf.logging.info("Starting training.")
-  hooks = [sync_optimizer.make_session_run_hook(is_chief)]
-
-  def make_update_op(update_thunks):
-    update_ops = [thunk() for thunk in update_thunks]
-    return tf.group(*update_ops)
-
-  if is_chief:
-    cov_update_op = make_update_op(cov_update_thunks)
-    with tf.control_dependencies([cov_update_op]):
-      inverse_op = tf.cond(
-          tf.equal(tf.mod(global_step, invert_every), 0),
-          lambda: make_update_op(inv_update_thunks),
-          tf.no_op)
-      with tf.control_dependencies([inverse_op]):
-        train_op = sync_optimizer.minimize(loss, global_step=global_step)
-  else:
-    train_op = sync_optimizer.minimize(loss, global_step=global_step)
-
-  with tf.train.MonitoredTrainingSession(
-      master=master,
-      is_chief=is_chief,
-      checkpoint_dir=checkpoint_dir,
-      hooks=hooks,
-      stop_grace_period_secs=0) as sess:
-    while not sess.should_stop():
-      global_step_, loss_, accuracy_, _ = sess.run(
-          [global_step, loss, accuracy, train_op])
-      tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
-                      global_step_, loss_, accuracy_)
-  return accuracy_
-
-
-def distributed_grads_and_ops_dedicated_workers(
-    task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
-    loss, accuracy, layer_collection):
-  """Minimize loss with a synchronous implementation of K-FAC.
-
-  Different workers are responsible for different parts of K-FAC's Ops. The
-  first 60% of tasks compute gradients; the next 20% accumulate covariance
-  statistics; the last 20% invert the matrices used to precondition gradients.
-  The chief worker applies the gradient.
-
-  Args:
-    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
-    is_chief: `boolean`, `True` if the worker is chief worker.
-    num_worker_tasks: int. Number of workers in this distributed training
-      setup.
-    num_ps_tasks: int. Number of parameter servers holding variables. If 0,
-      parameter servers are not used.
-    master: string. IP and port of TensorFlow runtime process. Set to empty
-      string to run locally.
-    checkpoint_dir: string or None. Path to store checkpoints under.
-    loss: 0-D Tensor. Loss to be minimized.
-    accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to
-      run with each step.
-    layer_collection: LayerCollection instance describing model architecture.
-      Used by K-FAC to construct preconditioner.
-
-  Returns:
-    final value for 'accuracy'.
-
-  Raises:
-    ValueError: if task_id >= num_worker_tasks.
-  """
-  sync_optimizer, optimizer, global_step = _make_distributed_train_op(
-      task_id, num_worker_tasks, num_ps_tasks, layer_collection)
-  _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars()
-  train_op = sync_optimizer.minimize(loss, global_step=global_step)
-  inv_update_queue = oq.OpQueue(inv_update_ops)
-
-  tf.logging.info("Starting training.")
-  is_chief = (task_id == 0)
-  hooks = [sync_optimizer.make_session_run_hook(is_chief)]
-  with tf.train.MonitoredTrainingSession(
-      master=master,
-      is_chief=is_chief,
-      checkpoint_dir=checkpoint_dir,
-      hooks=hooks,
-      stop_grace_period_secs=0) as sess:
-    while not sess.should_stop():
-      # Choose which op this task is responsible for running.
-      if _is_gradient_task(task_id, num_worker_tasks):
-        learning_op = train_op
-      elif _is_cov_update_task(task_id, num_worker_tasks):
-        learning_op = cov_update_op
-      elif _is_inv_update_task(task_id, num_worker_tasks):
-        # TODO(duckworthd): Running this op before cov_update_op has been run
-        # a few times can result in "InvalidArgumentError: Cholesky
-        # decomposition was not successful." Delay running this op until
-        # cov_update_op has been run a few times.
-        learning_op = inv_update_queue.next_op(sess)
-      else:
-        raise ValueError("Which op should task %d do?" % task_id)
-
-      global_step_, loss_, accuracy_, _ = sess.run(
-          [global_step, loss, accuracy, learning_op])
-      tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
-                      global_step_, loss_, accuracy_)
-
-  return accuracy_
-
-
-def train_mnist_single_machine(data_dir,
-                               num_epochs,
-                               use_fake_data=False,
-                               device="/gpu:0"):
-  """Train a ConvNet on MNIST.
-
-  Args:
-    data_dir: string. Directory to read MNIST examples from.
-    num_epochs: int. Number of passes to make over the training set.
-    use_fake_data: bool. If True, generate a synthetic dataset.
-    device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse
-      update ops are run on this device.
-
-  Returns:
-    accuracy of model on the final minibatch of training data.
-  """
-  # Load a dataset.
-  tf.logging.info("Loading MNIST into memory.")
-  examples, labels = mnist.load_mnist(
-      data_dir,
-      num_epochs=num_epochs,
-      batch_size=128,
-      use_fake_data=use_fake_data,
-      flatten_images=False)
-
-  # Build a ConvNet.
-  layer_collection = lc.LayerCollection()
-  loss, accuracy = build_model(
-      examples, labels, num_labels=10, layer_collection=layer_collection)
-
-  # Fit model.
-  return minimize_loss_single_machine(
-      loss, accuracy, layer_collection, device=device)
-
-
-def train_mnist_multitower(data_dir, num_epochs, num_towers,
-                           use_fake_data=True, devices=None):
-  """Train a ConvNet on MNIST.
-
-  Training data is split equally among the towers. Each tower computes loss
-  on its own batch of data and the loss is aggregated on the CPU. The model
-  variables are placed on first tower. The covariance and inverse update ops
-  and variables are placed on GPUs in a round robin manner.
-
-  Args:
-    data_dir: string. Directory to read MNIST examples from.
-    num_epochs: int. Number of passes to make over the training set.
-    num_towers: int. Number of CPUs to split inference across.
-    use_fake_data: bool. If True, generate a synthetic dataset.
-    devices: list of strings or None. CPU or GPU devices on which the
-      covariance and inverse update ops are run.
-
-  Returns:
-    accuracy of model on the final minibatch of training data.
-  """
-  if devices:
-    device_count = {"GPU": num_towers}
-  else:
-    device_count = {"CPU": num_towers}
-
-  devices = devices or [
-      "/cpu:{}".format(tower_id) for tower_id in range(num_towers)
-  ]
-  # Load a dataset.
-  tf.logging.info("Loading MNIST into memory.")
-  tower_batch_size = 128
-  batch_size = tower_batch_size * num_towers
-  tf.logging.info(
-      ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
-       "tower batch size.") % (batch_size, num_towers, tower_batch_size))
-  examples, labels = mnist.load_mnist(
-      data_dir,
-      num_epochs=num_epochs,
-      batch_size=batch_size,
-      use_fake_data=use_fake_data,
-      flatten_images=False)
-
-  # Split minibatch across towers.
-  examples = tf.split(examples, num_towers)
-  labels = tf.split(labels, num_towers)
-
-  # Build a ConvNet. Each tower's layers will be added to the LayerCollection.
-  layer_collection = lc.LayerCollection()
-  tower_results = []
-  for tower_id in range(num_towers):
-    with tf.device(devices[tower_id]):
-      with tf.name_scope("tower%d" % tower_id):
-        with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
-          tf.logging.info("Building tower %d." % tower_id)
-          tower_results.append(
-              build_model(examples[tower_id], labels[tower_id], 10,
-                          layer_collection))
-  losses, accuracies = zip(*tower_results)
-
-  # Average across towers.
-  loss = tf.reduce_mean(losses)
-  accuracy = tf.reduce_mean(accuracies)
-
-  # Fit model.
-  session_config = tf.ConfigProto(
-      allow_soft_placement=False,
-      device_count=device_count,
-  )
-
-  g_step = tf.train.get_or_create_global_step()
-  optimizer = opt.KfacOptimizer(
-      learning_rate=0.0001,
-      cov_ema_decay=0.95,
-      damping=0.001,
-      layer_collection=layer_collection,
-      placement_strategy="round_robin",
-      cov_devices=devices,
-      inv_devices=devices,
-      momentum=0.9)
-  (cov_update_thunks,
-   inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
-  def make_update_op(update_thunks):
-    update_ops = [thunk() for thunk in update_thunks]
-    return tf.group(*update_ops)
-
-  cov_update_op = make_update_op(cov_update_thunks)
-  with tf.control_dependencies([cov_update_op]):
-    inverse_op = tf.cond(
-        tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
-        lambda: make_update_op(inv_update_thunks), tf.no_op)
-  with tf.control_dependencies([inverse_op]):
-    train_op = optimizer.minimize(loss, global_step=g_step)
-
-  tf.logging.info("Starting training.")
-  with tf.train.MonitoredTrainingSession(config=session_config) as sess:
-    while not sess.should_stop():
-      global_step_, loss_, accuracy_, _ = sess.run(
-          [g_step, loss, accuracy, train_op])
-
-      if global_step_ % _INVERT_EVERY == 0:
-        tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
-                        global_step_, loss_, accuracy_)
-
-
-def train_mnist_distributed_sync_replicas(task_id,
-                                          is_chief,
-                                          num_worker_tasks,
-                                          num_ps_tasks,
-                                          master,
-                                          data_dir,
-                                          num_epochs,
-                                          op_strategy,
-                                          use_fake_data=False):
-  """Train a ConvNet on MNIST using Sync replicas optimizer.
-
-  Args:
-    task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
-    is_chief: `boolean`, `True` if the worker is chief worker.
-    num_worker_tasks: int. Number of workers in this distributed training
-      setup.
-    num_ps_tasks: int. Number of parameter servers holding variables.
-    master: string. IP and port of TensorFlow runtime process.
-    data_dir: string. Directory to read MNIST examples from.
-    num_epochs: int. Number of passes to make over the training set.
-    op_strategy: `string`, Strategy to run the covariance and inverse ops.
-      If op_strategy == `chief_worker` then covariance and inverse update ops
-      are run on chief worker, otherwise they are run on dedicated workers.
-    use_fake_data: bool. If True, generate a synthetic dataset.
-
-  Returns:
-    accuracy of model on the final minibatch of training data.
-
-  Raises:
-    ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"].
-  """
-  # Load a dataset.
-  tf.logging.info("Loading MNIST into memory.")
-  examples, labels = mnist.load_mnist(
-      data_dir,
-      num_epochs=num_epochs,
-      batch_size=128,
-      use_fake_data=use_fake_data,
-      flatten_images=False)
-
-  # Build a ConvNet.
-  layer_collection = lc.LayerCollection()
-  with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
-    loss, accuracy = build_model(
-        examples, labels, num_labels=10, layer_collection=layer_collection)
-
-  # Fit model.
-  checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac")
-  if op_strategy == "chief_worker":
-    return distributed_grads_only_and_ops_chief_worker(
-        task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
-        checkpoint_dir, loss, accuracy, layer_collection)
-  elif op_strategy == "dedicated_workers":
-    return distributed_grads_and_ops_dedicated_workers(
-        task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
-        checkpoint_dir, loss, accuracy, layer_collection)
-  else:
-    raise ValueError("Only supported op strategies are: {}, {}".format(
-        "chief_worker", "dedicated_workers"))
-
-
-if __name__ == "__main__":
-  tf.app.run()
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py
deleted file mode 100644
index b4c2d4a9e9..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-Distributed training with sync replicas optimizer. See
-`convnet.train_mnist_distributed_sync_replicas` for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-from absl import flags
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import convnet
-
-FLAGS = flags.FLAGS
-flags.DEFINE_integer("task", -1, "Task identifier")
-flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
-flags.DEFINE_string(
-    "cov_inv_op_strategy", "chief_worker",
-    "In dist training mode run the cov, inv ops on chief or dedicated workers."
-)
-flags.DEFINE_string("master", "local", "Session master.")
-flags.DEFINE_integer("ps_tasks", 2,
-                     "Number of tasks in the parameter server job.")
-flags.DEFINE_integer("replicas_to_aggregate", 5,
-                     "Number of replicas to aggregate.")
-flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.")
-flags.DEFINE_integer("num_epochs", None, "Number of epochs.")
-
-
-def _is_chief():
-  """Determines whether a job is the chief worker."""
-  if "chief_worker" in FLAGS.brain_jobs:
-    return FLAGS.brain_job_name == "chief_worker"
-  else:
-    return FLAGS.task == 0
-
-
-def main(unused_argv):
-  _ = unused_argv
-  convnet.train_mnist_distributed_sync_replicas(
-      FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks,
-      FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs,
-      FLAGS.cov_inv_op_strategy)
-
-if __name__ == "__main__":
-  tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py
deleted file mode 100644
index 4249bf8a8d..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-Multi tower training mode. See `convnet.train_mnist_multitower` for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-from absl import flags
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import convnet
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir")
-flags.DEFINE_integer("num_towers", 2,
-                     "Number of towers for multi tower training.")
-
-
-def main(unused_argv):
-  _ = unused_argv
-  assert FLAGS.num_towers > 1
-  devices = ["/gpu:{}".format(tower_id)
-             for tower_id in range(FLAGS.num_towers)]
-  convnet.train_mnist_multitower(
-      FLAGS.data_dir,
-      num_epochs=200,
-      num_towers=FLAGS.num_towers,
-      devices=devices)
-
-
-if __name__ == "__main__":
-  tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
deleted file mode 100644
index 2c1f099360..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-Train on single machine. See `convnet.train_mnist_single_machine` for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-from absl import flags
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import convnet
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
-
-
-def main(unused_argv):
-  convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
-
-
-if __name__ == "__main__":
-  tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py
deleted file mode 100644
index ea2b252a05..0000000000
--- a/tensorflow/contrib/kfac/examples/mlp.py
+++ /dev/null
@@ -1,354 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train an MLP on MNIST using K-FAC.
-
-This library fits a 3-layer, tanh-activated MLP on MNIST using K-FAC. After
-~25k steps, this should reach perfect accuracy on the training set.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mnist
-
-lc = tf.contrib.kfac.layer_collection
-opt = tf.contrib.kfac.optimizer
-
-__all__ = [
-    "fc_layer",
-    "train_mnist",
-    "train_mnist_multitower",
-]
-
-
-def fc_layer(layer_id, inputs, output_size):
-  """Builds a fully connected layer.
-
-  Args:
-    layer_id: int. Integer ID for this layer's variables.
-    inputs: Tensor of shape [num_examples, input_size]. Each row corresponds
-      to a single example.
-    output_size: int. Number of output dimensions after fully connected layer.
-
-  Returns:
-    preactivations: Tensor of shape [num_examples, output_size]. Values of the
-      layer immediately before the activation function.
-    activations: Tensor of shape [num_examples, output_size]. Values of the
-      layer immediately after the activation function.
-    params: Tuple of (weights, bias), parameters for this layer.
-  """
-  # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
-  layer = tf.layers.Dense(
-      output_size,
-      kernel_initializer=tf.random_normal_initializer(),
-      name="fc_%d" % layer_id)
-  preactivations = layer(inputs)
-  activations = tf.nn.tanh(preactivations)
-
-  # layer.weights is a list. This converts it to a (hashable) tuple.
-  return preactivations, activations, (layer.kernel, layer.bias)
-
-
-def build_model(examples, labels, num_labels, layer_collection):
-  """Builds an MLP classification model.
-
-  Args:
-    examples: Tensor of shape [num_examples, num_features]. Represents inputs
-      of model.
-    labels: Tensor of shape [num_examples]. Contains integer IDs to be
-      predicted by softmax for each example.
-    num_labels: int. Number of distinct values 'labels' can take on.
-    layer_collection: LayerCollection instance describing model architecture.
-
-  Returns:
-    loss: 0-D Tensor representing loss to be minimized.
-    accuracy: 0-D Tensor representing model's accuracy.
-  """
-  # Build an MLP. For each layer, we'll keep track of the preactivations,
-  # activations, weights, and bias.
-  pre0, act0, params0 = fc_layer(layer_id=0, inputs=examples, output_size=128)
-  pre1, act1, params1 = fc_layer(layer_id=1, inputs=act0, output_size=64)
-  pre2, act2, params2 = fc_layer(layer_id=2, inputs=act1, output_size=32)
-  logits, _, params3 = fc_layer(
-      layer_id=3, inputs=act2, output_size=num_labels)
-  loss = tf.reduce_mean(
-      tf.nn.sparse_softmax_cross_entropy_with_logits(
-          labels=labels, logits=logits))
-  accuracy = tf.reduce_mean(
-      tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
-
-  # Register parameters. K-FAC needs to know about the inputs, outputs, and
-  # parameters of each layer and the logits powering the posterior probability
-  # over classes.
-  tf.logging.info("Building LayerCollection.")
-  layer_collection.register_fully_connected(params0, examples, pre0)
-  layer_collection.register_fully_connected(params1, act0, pre1)
-  layer_collection.register_fully_connected(params2, act1, pre2)
-  layer_collection.register_fully_connected(params3, act2, logits)
-  layer_collection.register_categorical_predictive_distribution(
-      logits, name="logits")
-
-  return loss, accuracy
-
-
-def minimize(loss, accuracy, layer_collection, num_towers,
-             session_config=None):
-  """Minimize 'loss' with KfacOptimizer.
-
-  Args:
-    loss: 0-D Tensor. Loss to be minimized.
-    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
-    layer_collection: LayerCollection instance. Describes layers in model.
-    num_towers: int. Number of CPUs to split minibatch across.
-    session_config: tf.ConfigProto. Configuration for tf.Session().
-
-  Returns:
-    accuracy of classifier on final minibatch.
-  """
-  devices = tuple("/cpu:%d" % tower_id for tower_id in range(num_towers))
-
-  # Train with K-FAC. We'll use a decreasing learning rate that's cut in 1/2
-  # every 10k iterations.
-  tf.logging.info("Building KFAC Optimizer.")
-  global_step = tf.train.get_or_create_global_step()
-  optimizer = opt.KfacOptimizer(
-      learning_rate=tf.train.exponential_decay(
-          0.00002, global_step, 10000, 0.5, staircase=True),
-      cov_ema_decay=0.95,
-      damping=0.0005,
-      layer_collection=layer_collection,
-      momentum=0.99,
-      placement_strategy="round_robin",
-      cov_devices=devices,
-      inv_devices=devices)
-
-  (cov_update_thunks,
-   inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
-  def make_update_op(update_thunks):
-    update_ops = [thunk() for thunk in update_thunks]
-    return tf.group(*update_ops)
-
-  # TODO(b/78537047): change (some) examples to use PeriodicInvCovUpdateKfacOpt
-  # once that gets moved over? Could still leave more advanced examples as
-  # they are (e.g. train_mnist_estimator in this file).
-
-  cov_update_op = make_update_op(cov_update_thunks)
-  with tf.control_dependencies([cov_update_op]):
-    # We update the inverses only every 100 iterations.
-    inverse_op = tf.cond(
-        tf.equal(tf.mod(global_step, 100), 0),
-        lambda: make_update_op(inv_update_thunks), tf.no_op)
-  with tf.control_dependencies([inverse_op]):
-    train_op = optimizer.minimize(loss, global_step=global_step)
-
-  tf.logging.info("Starting training.")
-  with tf.train.MonitoredTrainingSession(config=session_config) as sess:
-    while not sess.should_stop():
-      global_step_, loss_, accuracy_, _ = sess.run(
-          [global_step, loss, accuracy, train_op])
-
-      if global_step_ % 100 == 0:
-        tf.logging.info("global_step: %d | loss: %f | accuracy: %f",
-                        global_step_, loss_, accuracy_)
-
-  return accuracy_
-
-
-def train_mnist(data_dir, num_epochs, use_fake_data=False):
-  """Train an MLP on MNIST.
-
-  Args:
-    data_dir: string. Directory to read MNIST examples from.
-    num_epochs: int. Number of passes to make over the training set.
-    use_fake_data: bool. If True, generate a synthetic dataset.
-
-  Returns:
-    accuracy of model on the final minibatch of training data.
-  """
-  # Load a dataset.
-  tf.logging.info("Loading MNIST into memory.")
-  examples, labels = mnist.load_mnist(
-      data_dir,
-      num_epochs=num_epochs,
-      batch_size=64,
-      flatten_images=True,
-      use_fake_data=use_fake_data)
-
-  # Build an MLP. The model's layers will be added to the LayerCollection.
-  tf.logging.info("Building model.")
-  layer_collection = lc.LayerCollection()
-  loss, accuracy = build_model(examples, labels, 10, layer_collection)
-
-  # Fit model.
-  minimize(loss, accuracy, layer_collection, 1)
-
-
-def train_mnist_multitower(data_dir,
-                           num_epochs,
-                           num_towers,
-                           use_fake_data=False):
-  """Train an MLP on MNIST, splitting the minibatch across multiple towers.
-
-  Args:
-    data_dir: string. Directory to read MNIST examples from.
-    num_epochs: int. Number of passes to make over the training set.
-    num_towers: int. Number of CPUs to split minibatch across.
-    use_fake_data: bool. If True, generate a synthetic dataset.
-
-  Returns:
-    accuracy of model on the final minibatch of training data.
-  """
-  # Load a dataset.
-  tower_batch_size = 64
-  batch_size = tower_batch_size * num_towers
-  tf.logging.info(
-      ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
-       "tower batch size.") % (batch_size, num_towers, tower_batch_size))
-  examples, labels = mnist.load_mnist(
-      data_dir,
-      num_epochs=num_epochs,
-      batch_size=batch_size,
-      flatten_images=True,
-      use_fake_data=use_fake_data)
-
-  # Split minibatch across towers.
-  examples = tf.split(examples, num_towers)
-  labels = tf.split(labels, num_towers)
-
-  # Build an MLP. Each tower's layers will be added to the LayerCollection.
-  layer_collection = lc.LayerCollection()
-  tower_results = []
-  for tower_id in range(num_towers):
-    with tf.device("/cpu:%d" % tower_id):
-      with tf.name_scope("tower%d" % tower_id):
-        with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
-          tf.logging.info("Building tower %d." % tower_id)
-          tower_results.append(
-              build_model(examples[tower_id], labels[tower_id], 10,
-                          layer_collection))
-  losses, accuracies = zip(*tower_results)
-
-  # Average across towers.
-  loss = tf.reduce_mean(losses)
-  accuracy = tf.reduce_mean(accuracies)
-
-  # Fit model.
-  session_config = tf.ConfigProto(
-      allow_soft_placement=False, device_count={
-          "CPU": num_towers
-      })
-  return minimize(
-      loss, accuracy, layer_collection, num_towers,
-      session_config=session_config)
-
-
-def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False):
-  """Train an MLP on MNIST using tf.estimator.
-
-  Args:
-    data_dir: string. Directory to read MNIST examples from.
-    num_epochs: int. Number of passes to make over the training set.
-    use_fake_data: bool. If True, generate a synthetic dataset.
-
-  Returns:
-    accuracy of model on the final minibatch of training data.
-  """
-
-  # Load a dataset.
-  def input_fn():
-    tf.logging.info("Loading MNIST into memory.")
-    return mnist.load_mnist(
-        data_dir,
-        num_epochs=num_epochs,
-        batch_size=64,
-        flatten_images=True,
-        use_fake_data=use_fake_data)
-
-  def model_fn(features, labels, mode, params):
-    """Model function for MLP trained with K-FAC.
-
-    Args:
-      features: Tensor of shape [batch_size, input_size]. Input features.
-      labels: Tensor of shape [batch_size]. Target labels for training.
-      mode: tf.estimator.ModeKey. Must be TRAIN.
-      params: ignored.
-
-    Returns:
-      EstimatorSpec for training.
-
-    Raises:
-      ValueError: If 'mode' is anything other than TRAIN.
-    """
-    del params
-
-    if mode != tf.estimator.ModeKeys.TRAIN:
-      raise ValueError("Only training is supported with this API.")
-
-    # Build an MLP.
-    layer_collection = lc.LayerCollection()
-    loss, accuracy = build_model(
-        features, labels, num_labels=10, layer_collection=layer_collection)
-
-    # Train with K-FAC.
-    global_step = tf.train.get_or_create_global_step()
-    optimizer = opt.KfacOptimizer(
-        learning_rate=tf.train.exponential_decay(
-            0.00002, global_step, 10000, 0.5, staircase=True),
-        cov_ema_decay=0.95,
-        damping=0.0001,
-        layer_collection=layer_collection,
-        momentum=0.99)
-
-    (cov_update_thunks,
-     inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
-    def make_update_op(update_thunks):
-      update_ops = [thunk() for thunk in update_thunks]
-      return tf.group(*update_ops)
-
-    def make_batch_executed_op(update_thunks, batch_size=1):
-      return tf.group(*tf.contrib.kfac.utils.batch_execute(
-          global_step, update_thunks, batch_size=batch_size))
-
-    # Run cov_update_op every step. Run 1 inv_update_ops per step.
-    cov_update_op = make_update_op(cov_update_thunks)
-    with tf.control_dependencies([cov_update_op]):
-      # But make sure to execute all the inverse ops on the first step
-      inverse_op = tf.cond(tf.equal(global_step, 0),
-                           lambda: make_update_op(inv_update_thunks),
-                           lambda: make_batch_executed_op(inv_update_thunks))
-    with tf.control_dependencies([inverse_op]):
-      train_op = optimizer.minimize(loss, global_step=global_step)
-
-    # Print metrics every 5 sec.
-    hooks = [
-        tf.train.LoggingTensorHook(
-            {
-                "loss": loss,
-                "accuracy": accuracy
-            }, every_n_secs=5),
-    ]
-    return tf.estimator.EstimatorSpec(
-        mode=mode, loss=loss, train_op=train_op, training_hooks=hooks)
-
-  run_config = tf.estimator.RunConfig(
-      model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100)
-
-  # Train until input_fn() is empty with Estimator. This is a prerequisite for
-  # TPU compatibility.
-  estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
-  estimator.train(input_fn=input_fn)
diff --git a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py
deleted file mode 100644
index 9c34ade1d2..0000000000
--- a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train an MLP on MNIST using K-FAC.
-
-See mlp.py for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import sys
-
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mlp
-
-FLAGS = None
-
-
-def main(argv):
-  _ = argv
-  if FLAGS.use_estimator:
-    if FLAGS.num_towers != 1:
-      raise ValueError("Only 1 device supported in tf.estimator example.")
-    mlp.train_mnist_estimator(FLAGS.data_dir, num_epochs=200)
-  elif FLAGS.num_towers > 1:
-    mlp.train_mnist_multitower(
-        FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers)
-  else:
-    mlp.train_mnist(FLAGS.data_dir, num_epochs=200)
-
-
-if __name__ == "__main__":
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-      "--data_dir",
-      type=str,
-      default="/tmp/mnist",
-      help="Directory to store dataset in.")
-  parser.add_argument(
-      "--num_towers",
-      type=int,
-      default=1,
-      help="Number of CPUs to split minibatch across.")
-  parser.add_argument(
-      "--use_estimator",
-      action="store_true",
-      help="Use tf.estimator API to train.")
-  FLAGS, unparsed = parser.parse_known_args()
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/kfac/examples/mnist.py b/tensorflow/contrib/kfac/examples/mnist.py
deleted file mode 100644
index 547c4ab25d..0000000000
--- a/tensorflow/contrib/kfac/examples/mnist.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utilities for loading MNIST into TensorFlow."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-__all__ = [
-    'load_mnist',
-]
-
-
-def load_mnist(data_dir,
-               num_epochs,
-               batch_size,
-               flatten_images=True,
-               use_fake_data=False):
-  """Loads MNIST dataset into memory.
-
-  Args:
-    data_dir: string. Directory to read MNIST examples from.
-    num_epochs: int. Number of passes to make over the dataset.
-    batch_size: int. Number of examples per minibatch.
-    flatten_images: bool. If True, [28, 28, 1]-shaped images are flattened into
-      [784]-shaped vectors.
-    use_fake_data: bool. If True, generate a synthetic dataset rather than
-      reading MNIST in.
-
-  Returns:
-    examples: Tensor of shape [batch_size, 784] if 'flatten_images' is
-      True, else [batch_size, 28, 28, 1]. Each row is one example.
-      Values in [0, 1].
-    labels: Tensor of shape [batch_size]. Integer indices corresponding to
-      each example. Values in {0...9}.
-  """
-  if use_fake_data:
-    rng = np.random.RandomState(42)
-    num_examples = batch_size * 4
-    images = rng.rand(num_examples, 28 * 28)
-    if not flatten_images:
-      images = np.reshape(images, [num_examples, 28, 28, 1])
-    labels = rng.randint(10, size=num_examples)
-  else:
-    mnist_data = tf.contrib.learn.datasets.mnist.read_data_sets(
-        data_dir, reshape=flatten_images)
-    num_examples = len(mnist_data.train.labels)
-    images = mnist_data.train.images
-    labels = mnist_data.train.labels
-
-  dataset = tf.data.Dataset.from_tensor_slices((np.asarray(
-      images, dtype=np.float32), np.asarray(labels, dtype=np.int64)))
-  return (dataset.repeat(num_epochs).shuffle(num_examples).batch(batch_size)
-          .make_one_shot_iterator().get_next())
diff --git a/tensorflow/contrib/kfac/examples/tests/BUILD b/tensorflow/contrib/kfac/examples/tests/BUILD
deleted file mode 100644
index ede7f183fe..0000000000
--- a/tensorflow/contrib/kfac/examples/tests/BUILD
+++ /dev/null
@@ -1,52 +0,0 @@
-package(default_visibility = ["//visibility:private"])
-
-licenses(["notice"])  # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_test(
-    name = "mlp_test",
-    size = "large",
-    srcs = ["mlp_test.py"],
-    srcs_version = "PY2AND3",
-    tags = [
-        "no_pip",
-        "notsan",
-    ],
-    deps = [
-        "//tensorflow:tensorflow_py",
-        "//tensorflow/contrib/kfac/examples:mlp",
-        "//third_party/py/numpy",
-    ],
-)
-
-py_test(
-    name = "convnet_test",
-    size = "large",
-    srcs = ["convnet_test.py"],
-    srcs_version = "PY2AND3",
-    tags = [
-        "no_pip",
-        "notsan",
-    ],
-    deps = [
-        "//tensorflow:tensorflow_py",
-        "//tensorflow/contrib/kfac",
-        "//tensorflow/contrib/kfac/examples:convnet",
-        "//third_party/py/numpy",
-    ],
-)
-
-py_test(
-    name = "mnist_test",
-    srcs = ["mnist_test.py"],
-    srcs_version = "PY2AND3",
-    tags = ["no_pip"],
-    deps = [
-        "//tensorflow:tensorflow_py",
-        "//tensorflow/contrib/kfac/examples:mnist",
-        "//third_party/py/numpy",
-    ],
-)
diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py
deleted file mode 100644
index adecda7166..0000000000
--- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py
+++ /dev/null
@@ -1,166 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for convnet.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.kfac import layer_collection as lc
-from tensorflow.contrib.kfac.examples import convnet
-
-
-class ConvNetTest(tf.test.TestCase):
-
-  def testConvLayer(self):
-    with tf.Graph().as_default():
-      pre, act, (w, b) = convnet.conv_layer(
-          layer_id=1,
-          inputs=tf.zeros([5, 3, 3, 2]),
-          kernel_size=3,
-          out_channels=5)
-      self.assertShapeEqual(np.zeros([5, 3, 3, 5]), pre)
-      self.assertShapeEqual(np.zeros([5, 3, 3, 5]), act)
-      self.assertShapeEqual(np.zeros([3, 3, 2, 5]), tf.convert_to_tensor(w))
-      self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b))
-      self.assertIsInstance(w, tf.Variable)
-      self.assertIsInstance(b, tf.Variable)
-      self.assertIn("conv_1", w.op.name)
-      self.assertIn("conv_1", b.op.name)
-
-  def testMaxPoolLayer(self):
-    with tf.Graph().as_default():
-      act = convnet.max_pool_layer(
-          layer_id=1, inputs=tf.zeros([5, 6, 6, 2]), kernel_size=5, stride=3)
-      self.assertShapeEqual(np.zeros([5, 2, 2, 2]), act)
-      self.assertEqual(act.op.name, "pool_1/pool")
-
-  def testLinearLayer(self):
-    with tf.Graph().as_default():
-      act, (w, b) = convnet.linear_layer(
-          layer_id=1, inputs=tf.zeros([5, 20]), output_size=5)
-      self.assertShapeEqual(np.zeros([5, 5]), act)
-      self.assertShapeEqual(np.zeros([20, 5]), tf.convert_to_tensor(w))
-      self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b))
-      self.assertIsInstance(w, tf.Variable)
-      self.assertIsInstance(b, tf.Variable)
-      self.assertIn("fc_1", w.op.name)
-      self.assertIn("fc_1", b.op.name)
-
-  def testBuildModel(self):
-    with tf.Graph().as_default():
-      x = tf.placeholder(tf.float32, [None, 6, 6, 3])
-      y = tf.placeholder(tf.int64, [None])
-      layer_collection = lc.LayerCollection()
-      loss, accuracy = convnet.build_model(
-          x, y, num_labels=5, layer_collection=layer_collection)
-
-      # Ensure layers and logits were registered.
-      self.assertEqual(len(layer_collection.fisher_blocks), 3)
-      self.assertEqual(len(layer_collection.losses), 1)
-
-      # Ensure inference doesn't crash.
-      with self.test_session() as sess:
-        sess.run(tf.global_variables_initializer())
-        feed_dict = {
-            x: np.random.randn(10, 6, 6, 3).astype(np.float32),
-            y: np.random.randint(5, size=10).astype(np.int64),
-        }
-        sess.run([loss, accuracy], feed_dict=feed_dict)
-
-  def _build_toy_problem(self):
-    """Construct a toy linear regression problem.
-
-    Initial loss should be
-      1.25 = mean(0.5 * (1^2 + 2^2))
-
-    Returns:
-      loss: 0-D Tensor representing loss to be minimized.
-      accuracy: 0-D Tensor representing model accuracy.
-      layer_collection: LayerCollection instance describing model architecture.
- """ - x = np.asarray([[1.], [2.]]).astype(np.float32) - y = np.asarray([1., 2.]).astype(np.float32) - x, y = (tf.data.Dataset.from_tensor_slices((x, y)) - .repeat(100).batch(2).make_one_shot_iterator().get_next()) - w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer()) - y_hat = tf.matmul(x, w) - loss = tf.reduce_mean(0.5 * tf.square(y_hat - y)) - accuracy = loss - - layer_collection = lc.LayerCollection() - layer_collection.register_fully_connected(params=w, inputs=x, outputs=y_hat) - layer_collection.register_normal_predictive_distribution(y_hat) - - return loss, accuracy, layer_collection - - def testMinimizeLossSingleMachine(self): - with tf.Graph().as_default(): - loss, accuracy, layer_collection = self._build_toy_problem() - accuracy_ = convnet.minimize_loss_single_machine( - loss, accuracy, layer_collection, device="/cpu:0") - self.assertLess(accuracy_, 2.0) - - def testMinimizeLossDistributed(self): - with tf.Graph().as_default(): - loss, accuracy, layer_collection = self._build_toy_problem() - accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker( - task_id=0, - is_chief=True, - num_worker_tasks=1, - num_ps_tasks=0, - master="", - checkpoint_dir=None, - loss=loss, - accuracy=accuracy, - layer_collection=layer_collection) - self.assertLess(accuracy_, 2.0) - - def testTrainMnistSingleMachine(self): - with tf.Graph().as_default(): - # Ensure model training doesn't crash. - # - # Ideally, we should check that accuracy increases as the model converges, - # but there are too few parameters for the model to effectively memorize - # the training set the way an MLP can. - convnet.train_mnist_single_machine( - data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0") - - def testTrainMnistMultitower(self): - with tf.Graph().as_default(): - # Ensure model training doesn't crash. - convnet.train_mnist_multitower( - data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True) - - def testTrainMnistDistributed(self): - with tf.Graph().as_default(): - # Ensure model training doesn't crash. - convnet.train_mnist_distributed_sync_replicas( - task_id=0, - is_chief=True, - num_worker_tasks=1, - num_ps_tasks=0, - master="", - data_dir=None, - num_epochs=2, - op_strategy="chief_worker", - use_fake_data=True) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow/contrib/kfac/examples/tests/mlp_test.py b/tensorflow/contrib/kfac/examples/tests/mlp_test.py deleted file mode 100644 index 22da6c29f1..0000000000 --- a/tensorflow/contrib/kfac/examples/tests/mlp_test.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for mlp.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf - -from tensorflow.contrib.kfac.examples import mlp - - -class MlpTest(tf.test.TestCase): - - def testFcLayer(self): - with tf.Graph().as_default(): - pre, act, (w, b) = mlp.fc_layer( - layer_id=1, inputs=tf.zeros([5, 3]), output_size=10) - self.assertShapeEqual(np.zeros([5, 10]), pre) - self.assertShapeEqual(np.zeros([5, 10]), act) - self.assertShapeEqual(np.zeros([3, 10]), tf.convert_to_tensor(w)) - self.assertShapeEqual(np.zeros([10]), tf.convert_to_tensor(b)) - self.assertIsInstance(w, tf.Variable) - self.assertIsInstance(b, tf.Variable) - self.assertIn("fc_1/", w.op.name) - self.assertIn("fc_1/", b.op.name) - - def testTrainMnist(self): - with tf.Graph().as_default(): - # Ensure model training doesn't crash. - # - # Ideally, we should check that accuracy increases as the model converges, - # but that takes a non-trivial amount of compute. - mlp.train_mnist(data_dir=None, num_epochs=1, use_fake_data=True) - - def testTrainMnistMultitower(self): - with tf.Graph().as_default(): - # Ensure model training doesn't crash. - mlp.train_mnist_multitower( - data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True) - - def testTrainMnistEstimator(self): - with tf.Graph().as_default(): - # Ensure model training doesn't crash. - mlp.train_mnist_estimator(data_dir=None, num_epochs=1, use_fake_data=True) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow/contrib/kfac/examples/tests/mnist_test.py b/tensorflow/contrib/kfac/examples/tests/mnist_test.py deleted file mode 100644 index 92f8462357..0000000000 --- a/tensorflow/contrib/kfac/examples/tests/mnist_test.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ==============================================================================
-"""Tests for mnist.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mnist
-
-
-class MnistTest(tf.test.TestCase):
-
-  def testValues(self):
-    """Ensure values are in their expected range."""
-    with tf.Graph().as_default():
-      examples, labels = mnist.load_mnist(
-          data_dir=None, num_epochs=1, batch_size=64, use_fake_data=True)
-
-      with self.test_session() as sess:
-        examples_, labels_ = sess.run([examples, labels])
-        self.assertTrue(np.all((0 <= examples_) & (examples_ < 1)))
-        self.assertTrue(np.all((0 <= labels_) & (labels_ < 10)))
-
-  def testFlattenedShapes(self):
-    """Ensure images are flattened into their appropriate shape."""
-    with tf.Graph().as_default():
-      examples, labels = mnist.load_mnist(
-          data_dir=None,
-          num_epochs=1,
-          batch_size=64,
-          flatten_images=True,
-          use_fake_data=True)
-
-      with self.test_session() as sess:
-        examples_, labels_ = sess.run([examples, labels])
-        self.assertEqual(examples_.shape, (64, 784))
-        self.assertEqual(labels_.shape, (64,))
-
-  def testNotFlattenedShapes(self):
-    """Ensure non-flattened images are their appropriate shape."""
-    with tf.Graph().as_default():
-      examples, labels = mnist.load_mnist(
-          data_dir=None,
-          num_epochs=1,
-          batch_size=64,
-          flatten_images=False,
-          use_fake_data=True)
-
-      with self.test_session() as sess:
-        examples_, labels_ = sess.run([examples, labels])
-        self.assertEqual(examples_.shape, (64, 28, 28, 1))
-        self.assertEqual(labels_.shape, (64,))
-
-
-if __name__ == '__main__':
-  tf.test.main()
diff --git a/tensorflow/contrib/kfac/g3doc/autoencoder.png b/tensorflow/contrib/kfac/g3doc/autoencoder.png
Binary files differ
deleted file mode 100644
index 20f93c7703..0000000000
--- a/tensorflow/contrib/kfac/g3doc/autoencoder.png
+++ /dev/null
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
deleted file mode 100644
index 6e4a8d71ba..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ /dev/null
@@ -1,160 +0,0 @@
-package(default_visibility = ["//visibility:private"])
-
-licenses(["notice"])  # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_test(
-    name = "estimator_test",
-    srcs = ["estimator_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        "//tensorflow/contrib/kfac/python/ops:fisher_estimator",
-        "//tensorflow/contrib/kfac/python/ops:layer_collection",
-        "//tensorflow/contrib/kfac/python/ops:utils",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:constant_op",
-        "//tensorflow/python:control_flow_ops",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:init_ops",
-        "//tensorflow/python:linalg_ops",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:random_ops",
-        "//tensorflow/python:training",
-        "//tensorflow/python:variable_scope",
-        "//tensorflow/python:variables",
-        "//third_party/py/numpy",
-    ],
-)
-
-py_test(
-    name = "fisher_factors_test",
-    srcs = ["fisher_factors_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
-        "//tensorflow/contrib/kfac/python/ops:fisher_factors",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:constant_op",
-
"//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_seed", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - -py_test( - name = "fisher_blocks_test", - srcs = ["fisher_blocks_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/kfac/python/ops:fisher_blocks", - "//tensorflow/contrib/kfac/python/ops:layer_collection", - "//tensorflow/contrib/kfac/python/ops:linear_operator", - "//tensorflow/contrib/kfac/python/ops:utils", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:random_seed", - "//tensorflow/python:state_ops", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - -py_test( - name = "layer_collection_test", - srcs = ["layer_collection_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/kfac/python/ops:fisher_blocks", - "//tensorflow/contrib/kfac/python/ops:fisher_factors", - "//tensorflow/contrib/kfac/python/ops:layer_collection", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:random_seed", - "//tensorflow/python:variable_scope", - ], -) - -py_test( - name = "optimizer_test", - srcs = ["optimizer_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/kfac/python/ops:fisher_factors", - "//tensorflow/contrib/kfac/python/ops:kfac_optimizer", - "//tensorflow/contrib/kfac/python/ops:layer_collection", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - -py_test( - name = "utils_test", - srcs = ["utils_test.py"], - srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows - deps = [ - "//tensorflow/contrib/kfac/python/ops:utils", - "//tensorflow/contrib/tpu", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:random_seed", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - -py_test( - name = "op_queue_test", - srcs = ["op_queue_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/kfac/python/ops:op_queue", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - ], -) - -py_test( - name = "loss_functions_test", - srcs = ["loss_functions_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/kfac/python/ops:loss_functions", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework_ops", - "//tensorflow/python:random_ops", - "//third_party/py/numpy", - ], -) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py deleted file 
mode 100644 index 76b31a5730..0000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ /dev/null @@ -1,310 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tf.contrib.kfac.estimator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.kfac.python.ops import estimator -from tensorflow.contrib.kfac.python.ops import layer_collection as lc -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import test -from tensorflow.python.training import training_util - -_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] - - -class EstimatorTest(test.TestCase): - - def setUp(self): - self._graph = ops.Graph() - with self._graph.as_default(): - self.layer_collection = lc.LayerCollection() - - self.inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32) - self.weights = variable_scope.get_variable( - "w", shape=(2, 2), dtype=dtypes.float32) - self.bias = variable_scope.get_variable( - "b", initializer=init_ops.zeros_initializer(), shape=(2, 1)) - self.output = math_ops.matmul(self.inputs, self.weights) + self.bias - - # Only register the weights. - self.layer_collection.register_fully_connected( - params=(self.weights,), inputs=self.inputs, outputs=self.output) - - self.outputs = math_ops.tanh(self.output) - self.targets = array_ops.zeros_like(self.outputs) - self.layer_collection.register_categorical_predictive_distribution( - logits=self.outputs, targets=self.targets) - - def testEstimatorInitManualRegistration(self): - with self._graph.as_default(): - # We should be able to build an estimator for only the registered vars. - estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection - ) - - # Check that we throw an error if we try to build an estimator for vars - # that were not manually registered. - with self.assertRaises(ValueError): - est = estimator.FisherEstimatorRoundRobin( - variables=[self.weights, self.bias], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection - ) - est.make_vars_and_create_op_thunks() - - # Check that we throw an error if we don't include registered variables, - # i.e. 
self.weights - with self.assertRaises(ValueError): - est = estimator.FisherEstimatorRoundRobin( - variables=[], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection) - est.make_vars_and_create_op_thunks() - - @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) - def testVariableWrongNumberOfUses(self, mock_uses): - with self.assertRaises(ValueError): - est = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection) - est.make_vars_and_create_op_thunks() - - def testInvalidEstimationMode(self): - with self.assertRaises(ValueError): - est = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection, - estimation_mode="not_a_real_mode") - est.make_vars_and_create_op_thunks() - - def testGradientsModeBuild(self): - with self._graph.as_default(): - est = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection, - estimation_mode="gradients") - est.make_vars_and_create_op_thunks() - - def testEmpiricalModeBuild(self): - with self._graph.as_default(): - est = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection, - estimation_mode="empirical") - est.make_vars_and_create_op_thunks() - - def testCurvaturePropModeBuild(self): - with self._graph.as_default(): - est = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection, - estimation_mode="curvature_prop") - est.make_vars_and_create_op_thunks() - - def testExactModeBuild(self): - with self._graph.as_default(): - est = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection, - estimation_mode="exact") - est.make_vars_and_create_op_thunks() - - def test_cov_update_thunks(self): - """Ensures covariance update ops run once per global_step.""" - with self._graph.as_default(), self.cached_session() as sess: - fisher_estimator = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - layer_collection=self.layer_collection, - damping=0.2, - cov_ema_decay=0.0) - - # Construct an op that executes one covariance update per step. - global_step = training_util.get_or_create_global_step() - (cov_variable_thunks, cov_update_op_thunks, _, - _) = fisher_estimator.create_ops_and_vars_thunks() - for thunk in cov_variable_thunks: - thunk() - cov_matrices = [ - fisher_factor.get_cov() - for fisher_factor in self.layer_collection.get_factors() - ] - cov_update_op = control_flow_ops.case( - [(math_ops.equal(global_step, i), thunk) - for i, thunk in enumerate(cov_update_op_thunks)]) - increment_global_step = global_step.assign_add(1) - - sess.run(variables.global_variables_initializer()) - initial_cov_values = sess.run(cov_matrices) - - # Ensure there's one update per covariance matrix. - self.assertEqual(len(cov_matrices), len(cov_update_op_thunks)) - - # Test is no-op if only 1 covariance matrix. 
-      assert len(cov_matrices) > 1
-
-      for i in range(len(cov_matrices)):
-        # Compare new and old covariance values
-        new_cov_values = sess.run(cov_matrices)
-        is_cov_equal = [
-            np.allclose(initial_cov_value, new_cov_value)
-            for (initial_cov_value,
-                 new_cov_value) in zip(initial_cov_values, new_cov_values)
-        ]
-        num_cov_equal = sum(is_cov_equal)
-
-        # Ensure exactly one covariance matrix changes per step.
-        self.assertEqual(num_cov_equal, len(cov_matrices) - i)
-
-        # Run all covariance update ops.
-        sess.run(cov_update_op)
-        sess.run(increment_global_step)
-
-  def test_round_robin_placement(self):
-    """Check if the ops and variables are placed on devices correctly."""
-    with self._graph.as_default():
-      fisher_estimator = estimator.FisherEstimatorRoundRobin(
-          variables=[self.weights],
-          layer_collection=self.layer_collection,
-          damping=0.2,
-          cov_ema_decay=0.0,
-          cov_devices=["/cpu:{}".format(i) for i in range(2)],
-          inv_devices=["/cpu:{}".format(i) for i in range(2)])
-
-      # Construct an op that executes one covariance update per step.
-      (cov_update_thunks,
-       inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks(
-           scope="test")
-      cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
-      inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
-      self.assertEqual(cov_update_ops[0].device, "/device:CPU:0")
-      self.assertEqual(cov_update_ops[1].device, "/device:CPU:1")
-      self.assertEqual(inv_update_ops[0].device, "/device:CPU:0")
-      self.assertEqual(inv_update_ops[1].device, "/device:CPU:1")
-      cov_matrices = [
-          fisher_factor.get_cov()
-          for fisher_factor in self.layer_collection.get_factors()
-      ]
-      inv_matrices = [
-          matrix
-          for fisher_factor in self.layer_collection.get_factors()
-          for matrix in fisher_factor._matpower_by_exp_and_damping.values()
-      ]
-      self.assertEqual(cov_matrices[0].device, "/device:CPU:0")
-      self.assertEqual(cov_matrices[1].device, "/device:CPU:1")
-      # Inverse matrices need to be explicitly placed.
-      self.assertEqual(inv_matrices[0].device, "")
-      self.assertEqual(inv_matrices[1].device, "")
-
-  def test_inv_update_thunks(self):
-    """Ensures inverse update ops run once per global_step."""
-    with self._graph.as_default(), self.cached_session() as sess:
-      fisher_estimator = estimator.FisherEstimatorRoundRobin(
-          variables=[self.weights],
-          layer_collection=self.layer_collection,
-          damping=0.2,
-          cov_ema_decay=0.0)
-
-      # Construct op that updates one inverse per global step.
-      global_step = training_util.get_or_create_global_step()
-      (cov_variable_thunks, _, inv_variable_thunks,
-       inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks()
-      for thunk in cov_variable_thunks:
-        thunk()
-      for thunk in inv_variable_thunks:
-        thunk()
-      inv_matrices = [
-          matrix
-          for fisher_factor in self.layer_collection.get_factors()
-          for matrix in fisher_factor._matpower_by_exp_and_damping.values()
-      ]
-      inv_update_op = control_flow_ops.case(
-          [(math_ops.equal(global_step, i), thunk)
-           for i, thunk in enumerate(inv_update_op_thunks)])
-      increment_global_step = global_step.assign_add(1)
-
-      sess.run(variables.global_variables_initializer())
-      initial_inv_values = sess.run(inv_matrices)
-
-      # Ensure there's one update per inverse matrix. This is true as long as
-      # there's no fan-in/fan-out or parameter re-use.
-      self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))
-
-      # Test is no-op if only 1 inverse matrix.
-      assert len(inv_matrices) > 1
-
-      # Assign each covariance matrix a value other than the identity. This
-      # ensures that the inverse matrices are updated to something different as
-      # well.
-      cov_matrices = [
-          fisher_factor.get_cov()
-          for fisher_factor in self.layer_collection.get_factors()
-      ]
-      sess.run([
-          cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
-          for cov_matrix in cov_matrices
-      ])
-
-      for i in range(len(inv_matrices)):
-        # Compare new and old inverse values
-        new_inv_values = sess.run(inv_matrices)
-        is_inv_equal = [
-            np.allclose(initial_inv_value, new_inv_value)
-            for (initial_inv_value,
-                 new_inv_value) in zip(initial_inv_values, new_inv_values)
-        ]
-        num_inv_equal = sum(is_inv_equal)
-
-        # Ensure exactly one inverse matrix changes per step.
-        self.assertEqual(num_inv_equal, len(inv_matrices) - i)
-
-        # Run all inverse update ops.
-        sess.run(inv_update_op)
-        sess.run(increment_global_step)
-
-
-if __name__ == "__main__":
-  test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
deleted file mode 100644
index f845def507..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ /dev/null
@@ -1,1018 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.fisher_blocks."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
-from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
-from tensorflow.contrib.kfac.python.ops import layer_collection as lc
-from tensorflow.contrib.kfac.python.ops import linear_operator as lo
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.platform import test
-
-
-# We need to set these constants since the numerical values used in the tests
-# were chosen when these used to be the defaults.
-ff.set_global_constants(init_covariances_at_zero=False,
                        zero_debias=False,
                        init_inverses_at_zero=False)
-
-# TODO(b/78538100): As far as I can tell, all the tests that say "Make sure our
-# inverse is something other than the identity" are actually broken. They never
-# run the covariance update ops and so the inverse actually is the identity
-# (possibly plus the damping term, which would still make it a multiple of the
-# identity).
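For reference, the "explicit" baselines that the block tests below compare against are plain dense solves against the damped Fisher, i.e. inv(F + damping * I) * v. A minimal numpy sketch of that check, with made-up values that mirror _make_psd below (it does not call any of the deleted K-FAC code):

import numpy as np

# A small PSD "Fisher" of the kind _make_psd constructs: all ones, with
# 2 + i added along the diagonal, which makes it diagonally dominant and
# hence positive semi-definite.
dim = 3
fisher = np.ones((dim, dim), dtype=np.float32)
fisher[np.arange(dim), np.arange(dim)] = 2. + np.arange(dim)

damping = 0.5
v = np.array([4., 5., 6.], dtype=np.float32)

# The explicit damped inverse-vector product the tests compare against:
# (F + damping * I)^{-1} v, computed with a dense solve.
explicit = np.linalg.solve(fisher + damping * np.eye(dim, dtype=np.float32), v)
print(explicit)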
- - -def _make_psd(dim): - """Constructs a PSD matrix of the given dimension.""" - mat = np.ones((dim, dim), dtype=np.float32) - mat[np.arange(dim), np.arange(dim)] = 2. + np.arange(dim) - return array_ops.constant(mat) - - -class UtilsTest(test.TestCase): - - def testComputePiTracenorm(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - diag = ops.convert_to_tensor([1., 2., 0., 1.]) - left_factor = lo.LinearOperatorDiag(diag) - right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2])) - - # pi is the sqrt of the left trace norm divided by the right trace norm - pi = fb.compute_pi_tracenorm(left_factor, right_factor) - - pi_val = sess.run(pi) - self.assertEqual(1., pi_val) - - -class FullFBTest(test.TestCase): - - def testFullFBInitSingleTensor(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - - self.assertAllEqual(params, block.tensors_to_compute_grads()) - - def testFullFBInitTensorTuple(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - - self.assertAllEqual(params, block.tensors_to_compute_grads()) - - def testInstantiateFactors(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - - grads = (params[0]**2, math_ops.sqrt(params[1])) - block.instantiate_factors(grads, 0.5) - - def testMultiplyInverseTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - grads = (params[0]**2, math_ops.sqrt(params[1])) - block.instantiate_factors((grads,), 0.5) - block._factor.instantiate_cov_variables() - block.register_inverse() - block._factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._factor.make_inverse_update_ops()) - - vector = array_ops.ones(3,) * 2 - output = block.multiply_inverse(vector) - - self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) - - def testMultiplyInverseNotTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = array_ops.constant([[1.], [2.]]) - block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - grads = params**2 - block.instantiate_factors((grads,), 0.5) - block._factor.instantiate_cov_variables() - block.register_inverse() - block._factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. 
- sess.run(tf_variables.global_variables_initializer()) - sess.run(block._factor.make_inverse_update_ops()) - - vector = array_ops.ones(2,) * 2 - output = block.multiply_inverse(vector) - - self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) - - def testMultiplyInverseAgainstExplicit(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - grads = (array_ops.constant([2., 3.]), array_ops.constant(4.)) - damping = 0.5 - block.instantiate_factors((grads,), damping) - block._factor.instantiate_cov_variables() - block.register_inverse() - block._factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. - sess.run(state_ops.assign(block._factor._cov, _make_psd(3))) - sess.run(block._factor.make_inverse_update_ops()) - - v_flat = np.array([4., 5., 6.], dtype=np.float32) - vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) - output = block.multiply_inverse(vector) - output_flat = sess.run(utils.tensors_to_column(output)).ravel() - - full = sess.run(block.full_fisher_block()) - explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat) - - self.assertAllClose(output_flat, explicit) - - -class NaiveDiagonalFBTest(test.TestCase): - - def testNaiveDiagonalFBInitSingleTensor(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - - self.assertAllEqual(params, block.tensors_to_compute_grads()) - - def testNaiveDiagonalFBInitTensorTuple(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - - self.assertAllEqual(params, block.tensors_to_compute_grads()) - - def testInstantiateFactors(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - - grads = (params[0]**2, math_ops.sqrt(params[1])) - block.instantiate_factors(grads, 0.5) - - def testMultiplyInverseTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - grads = (params[0]**2, math_ops.sqrt(params[1])) - block.instantiate_factors((grads,), 0.5) - block._factor.instantiate_cov_variables() - - # Make sure our inverse is something other than the identity. 
- sess.run(tf_variables.global_variables_initializer()) - sess.run(block._factor.make_inverse_update_ops()) - - vector = array_ops.ones(3,) * 2 - output = block.multiply_inverse(vector) - - self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) - - def testMultiplyInverseNotTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = array_ops.constant([[1.], [2.]]) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - grads = params**2 - block.instantiate_factors((grads,), 0.5) - block._factor.instantiate_cov_variables() - - # Make sure our inverse is something other than the identity. - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._factor.make_inverse_update_ops()) - vector = array_ops.ones(2,) * 2 - output = block.multiply_inverse(vector) - - self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) - - def testMultiplyInverseAgainstExplicit(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - grads = (params[0]**2, math_ops.sqrt(params[1])) - damping = 0.5 - block.instantiate_factors((grads,), damping) - block._factor.instantiate_cov_variables() - - cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1]) - sess.run(state_ops.assign(block._factor._cov, cov)) - sess.run(block._factor.make_inverse_update_ops()) - - v_flat = np.array([4., 5., 6.], dtype=np.float32) - vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) - output = block.multiply_inverse(vector) - output_flat = sess.run(utils.tensors_to_column(output)).ravel() - - full = sess.run(block.full_fisher_block()) - explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat) - self.assertAllClose(output_flat, explicit) - - -class FullyConnectedDiagonalFBTest(test.TestCase): - - def setUp(self): - super(FullyConnectedDiagonalFBTest, self).setUp() - - self.batch_size = 4 - self.input_size = 6 - self.output_size = 3 - - self.inputs = np.random.randn(self.batch_size, self.input_size).astype( - np.float32) - self.outputs = np.zeros([self.batch_size, self.output_size]).astype( - np.float32) - self.output_grads = np.random.randn(self.batch_size, - self.output_size).astype(np.float32) - self.w = np.random.randn(self.input_size, self.output_size).astype( - np.float32) - self.b = np.random.randn(self.output_size).astype(np.float32) - - def fisherApprox(self, has_bias=False): - """Fisher approximation using default inputs.""" - if has_bias: - inputs = np.concatenate( - [self.inputs, np.ones([self.batch_size, 1])], axis=1) - else: - inputs = self.inputs - return self.buildDiagonalFisherApproximation(inputs, self.output_grads) - - def buildDiagonalFisherApproximation(self, inputs, output_grads): - """Builds explicit diagonal Fisher approximation. - - Fisher's diagonal is (d loss / d w)'s elements squared for - d/dw = E[outer(input, output_grad)] - - where the expectation is taken over examples. - - Args: - inputs: np.array of shape [batch_size, input_size]. - output_grads: np.array of shape [batch_size, output_size]. - - Returns: - Diagonal np.array of shape [num_params, num_params] for num_params = - input_size * output_size. 
- """ - batch_size = inputs.shape[0] - assert output_grads.shape[0] == batch_size - input_size = inputs.shape[1] - output_size = output_grads.shape[1] - fisher_diag = np.zeros((input_size, output_size)) - for i in range(batch_size): - fisher_diag += np.square(np.outer(inputs[i], output_grads[i])) - return np.diag(fisher_diag.flatten()) / batch_size - - def testMultiply(self): - result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], - [self.output_grads]) - - # Construct Fisher-vector product. - expected_result = self.fisherApprox().dot(self.w.flatten()) - expected_result = expected_result.reshape( - [self.input_size, self.output_size]) - - self.assertAllClose(expected_result, result) - - def testMultiplyInverse(self): - _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], - [self.output_grads]) - - # Construct inverse Fisher-vector product. - expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten()) - expected_result = expected_result.reshape( - [self.input_size, self.output_size]) - - self.assertAllClose(expected_result, result) - - def testRegisterAdditionalTower(self): - """Ensure 1 big tower and 2 small towers are equivalent.""" - multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( - self.w, [self.inputs], [self.outputs], [self.output_grads]) - multiply_result_small, multiply_inverse_result_small = ( - self.runFisherBlockOps(self.w, np.split(self.inputs, 2), - np.split(self.outputs, 2), - np.split(self.output_grads, 2))) - - self.assertAllClose(multiply_result_big, multiply_result_small) - self.assertAllClose(multiply_inverse_result_big, - multiply_inverse_result_small) - - def testMultiplyHasBias(self): - result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs], - [self.outputs], [self.output_grads]) - expected_result = self.fisherApprox(True).dot( - np.concatenate([self.w.flatten(), self.b.flatten()])) - expected_result = expected_result.reshape( - [self.input_size + 1, self.output_size]) - expected_result = (expected_result[:-1], expected_result[-1]) - - self.assertEqual(len(result), 2) - self.assertAllClose(expected_result[0], result[0]) - self.assertAllClose(expected_result[1], result[1]) - - def runFisherBlockOps(self, params, inputs, outputs, output_grads): - """Run Ops guaranteed by FisherBlock interface. - - Args: - params: Tensor or 2-tuple of Tensors. Represents weights or weights and - bias of this layer. - inputs: list of Tensors of shape [batch_size, input_size]. Inputs to - layer. - outputs: list of Tensors of shape [batch_size, output_size]. - Preactivations produced by layer. - output_grads: list of Tensors of shape [batch_size, output_size]. - Gradient of loss with respect to 'outputs'. 
-
-    Returns:
-      multiply_result: Result of FisherBlock.multiply(params)
-      multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
-    """
-    with ops.Graph().as_default(), self.cached_session() as sess:
-      inputs = as_tensors(inputs)
-      outputs = as_tensors(outputs)
-      output_grads = as_tensors(output_grads)
-      params = as_tensors(params)
-
-      block = fb.FullyConnectedDiagonalFB(
-          lc.LayerCollection(), has_bias=isinstance(params, (tuple, list)))
-      for (i, o) in zip(inputs, outputs):
-        block.register_additional_tower(i, o)
-
-      block.instantiate_factors((output_grads,), damping=0.0)
-      block._factor.instantiate_cov_variables()
-
-      sess.run(tf_variables.global_variables_initializer())
-      sess.run(block._factor.make_covariance_update_op(0.0))
-      multiply_result = sess.run(block.multiply(params))
-      multiply_inverse_result = sess.run(block.multiply_inverse(params))
-
-      return multiply_result, multiply_inverse_result
-
-
-class EmbeddingKFACFBTest(test.TestCase):
-
-  def testInstantiateFactors(self):
-    with ops.Graph().as_default():
-      random_seed.set_random_seed(200)
-
-      # Create a Fisher Block.
-      vocab_size = 5
-      block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
-
-      # Add some examples.
-      inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
-      outputs = array_ops.constant([[0.], [1.], [2.]])
-      block.register_additional_tower(inputs, outputs)
-
-      # Instantiate factor's variables. Ensure it doesn't fail.
-      grads = outputs**2.
-      damping = array_ops.constant(0.)
-      block.instantiate_factors(((grads,),), damping)
-
-  def testMultiplyInverse(self):
-    with ops.Graph().as_default(), self.cached_session() as sess:
-      random_seed.set_random_seed(200)
-
-      # Create a Fisher Block.
-      vocab_size = 5
-      block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
-
-      # Add some examples.
-      inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
-      outputs = array_ops.constant([[0.], [1.], [2.]])
-      block.register_additional_tower(inputs, outputs)
-
-      # Instantiate factor's variables. Ensure it doesn't fail.
-      grads = outputs**2.
-      damping = array_ops.constant(0.)
-      block.instantiate_factors(((grads,),), damping)
-      block._input_factor.instantiate_cov_variables()
-      block._output_factor.instantiate_cov_variables()
-      block.register_inverse()
-      block._input_factor.instantiate_inv_variables()
-      block._output_factor.instantiate_inv_variables()
-
-      # Create a sparse update.
-      indices = array_ops.constant([1, 3, 4])
-      values = array_ops.constant([[1.], [1.], [1.]])
-      sparse_vector = ops.IndexedSlices(
-          values, indices, dense_shape=[vocab_size, 1])
-      dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1])
-
-      # Compare inverse Fisher-vector product against explicit result.
- result = block.multiply_inverse(sparse_vector) - expected_result = linalg_ops.matrix_solve(block.full_fisher_block(), - dense_vector) - - sess.run(tf_variables.global_variables_initializer()) - self.assertAlmostEqual( - sess.run(expected_result[1]), sess.run(result.values[0])) - self.assertAlmostEqual( - sess.run(expected_result[3]), sess.run(result.values[1])) - self.assertAlmostEqual( - sess.run(expected_result[4]), sess.run(result.values[2])) - - -class FullyConnectedKFACBasicFBTest(test.TestCase): - - def testFullyConnectedKFACBasicFBInit(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - inputs = array_ops.constant([1., 2.]) - outputs = array_ops.constant([3., 4.]) - block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection()) - block.register_additional_tower(inputs, outputs) - - self.assertAllEqual([outputs], block.tensors_to_compute_grads()) - - def testInstantiateFactorsHasBias(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - inputs = array_ops.constant([[1., 2.], [3., 4.]]) - outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True) - block.register_additional_tower(inputs, outputs) - - grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) - - def testInstantiateFactorsNoBias(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - inputs = array_ops.constant([[1., 2.], [3., 4.]]) - outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_tower(inputs, outputs) - - grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) - - def testMultiplyInverseTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]]) - outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_tower(inputs, outputs) - grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) - - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. 
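      # (See the TODO at the top of this file: comments like the one above are
      # suspect, since no covariance update op is run first and the "inverse"
      # may in fact still be a multiple of the identity.)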
- sess.run(tf_variables.global_variables_initializer()) - sess.run(block._input_factor.make_inverse_update_ops()) - sess.run(block._output_factor.make_inverse_update_ops()) - - vector = ( - np.arange(2, 6).reshape(2, 2).astype(np.float32), # - np.arange(1, 3).reshape(2, 1).astype(np.float32)) - output = block.multiply_inverse((array_ops.constant(vector[0]), - array_ops.constant(vector[1]))) - - output = sess.run(output) - self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]], - output[0]) - self.assertAllClose([0.343146, 0.686291], output[1]) - - def testMultiplyInverseNotTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - inputs = array_ops.constant([[1., 2.], [3., 4.]]) - outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_tower(inputs, outputs) - grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._input_factor.make_inverse_update_ops()) - sess.run(block._output_factor.make_inverse_update_ops()) - - vector = np.arange(2, 6).reshape(2, 2).astype(np.float32) - output = block.multiply_inverse(array_ops.constant(vector)) - - self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]], - sess.run(output)) - - def testMultiplyInverseAgainstExplicit(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - input_dim, output_dim = 3, 2 - inputs = array_ops.zeros([32, input_dim]) - outputs = array_ops.zeros([32, output_dim]) - params = array_ops.zeros([input_dim, output_dim]) - block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_tower(inputs, outputs) - grads = outputs**2 - damping = 0. # This test is only valid without damping. 
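      # (Why zero damping: multiply_inverse() appears to damp each Kronecker
      # factor separately, in the style of the K-FAC paper's factored Tikhonov
      # damping, so its result matches the exact damped inverse
      # inv(A kron B + damping * I) only when damping == 0.)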
- block.instantiate_factors(((grads,),), damping) - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - - sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3))) - sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) - - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - sess.run(block._input_factor.make_inverse_update_ops()) - sess.run(block._output_factor.make_inverse_update_ops()) - - v_flat = np.arange(6, dtype=np.float32) - vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) - output = block.multiply_inverse(vector) - output_flat = sess.run(utils.tensors_to_column(output)).ravel() - - full = sess.run(block.full_fisher_block()) - explicit = np.dot(np.linalg.inv(full + damping * np.eye(6)), v_flat) - - self.assertAllClose(output_flat, explicit) - - -class ConvDiagonalFBTest(test.TestCase): - - def setUp(self): - super(ConvDiagonalFBTest, self).setUp() - - self.batch_size = 2 - self.height = 8 - self.width = 4 - self.input_channels = 6 - self.output_channels = 3 - self.kernel_size = 1 - - self.inputs = np.random.randn(self.batch_size, self.height, self.width, - self.input_channels).astype(np.float32) - self.outputs = np.zeros( - [self.batch_size, self.height, self.width, - self.output_channels]).astype(np.float32) - self.output_grads = np.random.randn( - self.batch_size, self.height, self.width, self.output_channels).astype( - np.float32) - self.w = np.random.randn(self.kernel_size, self.kernel_size, - self.input_channels, self.output_channels).astype( - np.float32) - self.b = np.random.randn(self.output_channels).astype(np.float32) - - def fisherApprox(self, has_bias=False): - """Fisher approximation using default inputs.""" - if has_bias: - inputs = np.concatenate( - [self.inputs, - np.ones([self.batch_size, self.height, self.width, 1])], - axis=-1) - else: - inputs = self.inputs - return self.buildDiagonalFisherApproximation(inputs, self.output_grads, - self.kernel_size) - - def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size): - r"""Builds explicit diagonal Fisher approximation. - - Fisher's diagonal is (d loss / d w)'s elements squared for - d/dw = E[\sum_{loc} outer(input_{loc}, output_grad_{loc})] - - where the expectation is taken over examples and the sum over (x, y) - locations upon which the convolution is applied. - - Args: - inputs: np.array of shape [batch_size, height, width, input_channels]. - output_grads: np.array of shape [batch_size, height, width, - output_channels]. - kernel_size: int. height and width of kernel. - - Returns: - Diagonal np.array of shape [num_params, num_params] for num_params = - kernel_size^2 * input_channels * output_channels. - """ - batch_size, height, width, input_channels = inputs.shape - assert output_grads.shape[0] == batch_size - assert output_grads.shape[1] == height - assert output_grads.shape[2] == width - output_channels = output_grads.shape[3] - - # If kernel_size == 1, then we don't need to worry about capturing context - # around the pixel upon which a convolution is applied. This makes testing - # easier. - assert kernel_size == 1, "kernel_size != 1 isn't supported." 
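    # With a 1x1 kernel every spatial location contributes its own
    # outer(input, output_grad) term to d/dw, so (height, width) can be
    # flattened into a single "locations" axis before the sum-then-square.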
-    num_locations = height * width
-    inputs = np.reshape(inputs, [batch_size, num_locations, input_channels])
-    output_grads = np.reshape(output_grads,
                              [batch_size, num_locations, output_channels])
-
-    fisher_diag = np.zeros((input_channels, output_channels))
-    for i in range(batch_size):
-      # Each example's approximation is a square(sum-of-outer-products).
-      example_fisher_diag = np.zeros((input_channels, output_channels))
-      for j in range(num_locations):
-        example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j])
-      fisher_diag += np.square(example_fisher_diag)
-
-    # Normalize by batch_size (not num_locations).
-    return np.diag(fisher_diag.flatten()) / batch_size
-
-  def testMultiply(self):
-    result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
                                       [self.output_grads])
-
-    # Construct Fisher-vector product.
-    expected_result = self.fisherApprox().dot(self.w.flatten())
-    expected_result = expected_result.reshape([
        self.kernel_size, self.kernel_size, self.input_channels,
        self.output_channels
    ])
-
-    self.assertAllClose(expected_result, result)
-
-  def testMultiplyInverse(self):
-    _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
                                       [self.output_grads])
-
-    # Construct inverse Fisher-vector product.
-    expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
-    expected_result = expected_result.reshape([
        self.kernel_size, self.kernel_size, self.input_channels,
        self.output_channels
    ])
-
-    self.assertAllClose(expected_result, result, atol=1e-3)
-
-  def testRegisterAdditionalTower(self):
-    """Ensure 1 big tower and 2 small towers are equivalent."""
-    multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
        self.w, [self.inputs], [self.outputs], [self.output_grads])
-    multiply_result_small, multiply_inverse_result_small = (
        self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
                               np.split(self.outputs, 2),
                               np.split(self.output_grads, 2)))
-
-    self.assertAllClose(multiply_result_big, multiply_result_small)
-    self.assertAllClose(multiply_inverse_result_big,
                        multiply_inverse_result_small)
-
-  def testMultiplyHasBias(self):
-    result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
                                       [self.outputs], [self.output_grads])
-    # Clone 'b' along 'input_channels' dimension.
-    b_filter = np.tile(
        np.reshape(self.b, [1, 1, 1, self.output_channels]),
        [self.kernel_size, self.kernel_size, 1, 1])
-    params = np.concatenate([self.w, b_filter], axis=2)
-    expected_result = self.fisherApprox(True).dot(params.flatten())
-
-    # Extract 'b' from concatenated parameters.
-    expected_result = expected_result.reshape([
        self.kernel_size, self.kernel_size, self.input_channels + 1,
        self.output_channels
    ])
-    expected_result = (expected_result[:, :, 0:-1, :],
                       np.reshape(expected_result[:, :, -1, :],
                                  [self.output_channels]))
-
-    self.assertEqual(len(result), 2)
-    self.assertAllClose(expected_result[0], result[0])
-    self.assertAllClose(expected_result[1], result[1])
-
-  def runFisherBlockOps(self, params, inputs, outputs, output_grads):
-    """Run Ops guaranteed by FisherBlock interface.
-
-    Args:
-      params: Tensor or 2-tuple of Tensors. Represents weights or weights and
        bias of this layer.
-      inputs: list of Tensors of shape [batch_size, height, width,
        input_channels]. Inputs to layer.
-      outputs: list of Tensors of shape [batch_size, height, width,
        output_channels]. Preactivations produced by layer.
-      output_grads: list of Tensors of shape [batch_size, height, width,
        output_channels]. Gradient of loss with respect to 'outputs'.
- - Returns: - multiply_result: Result of FisherBlock.multiply(params) - multiply_inverse_result: Result of FisherBlock.multiply_inverse(params) - """ - with ops.Graph().as_default(), self.cached_session() as sess: - inputs = as_tensors(inputs) - outputs = as_tensors(outputs) - output_grads = as_tensors(output_grads) - params = as_tensors(params) - - block = fb.ConvDiagonalFB( - lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME') - for (i, o) in zip(inputs, outputs): - block.register_additional_tower(i, o) - - block.instantiate_factors((output_grads,), damping=0.0) - block._factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._factor.make_covariance_update_op(0.0)) - multiply_result = sess.run(block.multiply(params)) - multiply_inverse_result = sess.run(block.multiply_inverse(params)) - - return multiply_result, multiply_inverse_result - - -class DepthwiseConvKFCBasicFBTest(test.TestCase): - - def testInstantiateFactors(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - params = random_ops.random_normal((3, 3, 8, 2)) - inputs = random_ops.random_normal((32, 5, 5, 8)) - outputs = random_ops.random_normal((32, 5, 5, 16)) - layer_collection = lc.LayerCollection() - block = fb.DepthwiseConvKFCBasicFB( - layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') - block.register_additional_tower(inputs, outputs) - grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) - - def testMultiplyInverse(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = random_ops.random_normal((3, 3, 8, 2)) - inputs = random_ops.random_normal((32, 5, 5, 8)) - outputs = random_ops.random_normal((32, 5, 5, 16)) - layer_collection = lc.LayerCollection() - block = fb.DepthwiseConvKFCBasicFB( - layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') - block.register_additional_tower(inputs, outputs) - grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - # Ensure inverse update op doesn't crash. - sess.run(tf_variables.global_variables_initializer()) - sess.run([ - factor.make_inverse_update_ops() - for factor in layer_collection.get_factors() - ]) - - # Ensure inverse-vector multiply doesn't crash. - output = block.multiply_inverse(params) - sess.run(output) - - # Ensure same shape. 
- self.assertAllEqual(output.shape, params.shape) - - -class ConvKFCBasicFBTest(test.TestCase): - - def _testConvKFCBasicFBInitParams(self, params): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - if isinstance(params, (list, tuple)): - params = [array_ops.constant(param) for param in params] - else: - params = array_ops.constant(params) - inputs = random_ops.random_normal((2, 2, 2)) - outputs = random_ops.random_normal((2, 2, 2)) - block = fb.ConvKFCBasicFB( - lc.LayerCollection(), params=params, padding='SAME') - block.register_additional_tower(inputs, outputs) - - self.assertAllEqual([outputs], block.tensors_to_compute_grads()) - - def testConvKFCBasicFBInitParamsParamsTuple(self): - self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])]) - - def testConvKFCBasicFBInitParamsParamsSingle(self): - self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])]) - - def testMultiplyInverseTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = random_ops.random_normal((2, 2, 2, 2)) - inputs = random_ops.random_normal((2, 2, 2, 2)) - outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB( - lc.LayerCollection(), params=params, padding='SAME') - block.register_additional_tower(inputs, outputs) - grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._input_factor.make_inverse_update_ops()) - sess.run(block._output_factor.make_inverse_update_ops()) - - vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32), - np.arange(2, 4).reshape(2, 1).astype(np.float32)) - output = block.multiply_inverse((array_ops.constant(vector[0]), - array_ops.constant(vector[1]))) - - output = sess.run(output) - self.assertAllClose([0.136455, 0.27291], output[0][0]) - self.assertAllClose([0.27291, 0.409365], output[1]) - - def testMultiplyInverseNotTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = random_ops.random_normal((2, 2, 2, 2)) - inputs = random_ops.random_normal((2, 2, 2, 2)) - outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB( - lc.LayerCollection(), params=params, padding='SAME') - block.register_additional_tower(inputs, outputs) - self.assertFalse(block._has_bias) - grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. 
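- # (Until the update ops run, the inverse variables still hold their
- # initial values; running them first yields genuine damped inverses of
- # the current covariance estimates.)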
- sess.run(tf_variables.global_variables_initializer()) - sess.run(block._input_factor.make_inverse_update_ops()) - sess.run(block._output_factor.make_inverse_update_ops()) - - vector = np.arange(1, 17).reshape(8, 2).astype(np.float32) - output = block.multiply_inverse(array_ops.constant(vector)) - - self.assertAllClose([0.136455, 0.27291], sess.run(output)[0]) - - def testMultiplyInverseNotTupleWithBias(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = [random_ops.random_normal((2, 2, 2, 2))] - inputs = random_ops.random_normal((2, 2, 2, 2)) - outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB( - lc.LayerCollection(), params=params, padding='SAME') - block.register_additional_tower(inputs, outputs) - self.assertTrue(block._has_bias) - grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._input_factor.make_inverse_update_ops()) - sess.run(block._output_factor.make_inverse_update_ops()) - - vector = np.arange(1, 19).reshape(9, 2).astype(np.float32) - output = block.multiply_inverse(array_ops.constant(vector)) - - self.assertAllClose([0.136455, 0.27291], sess.run(output)[0]) - - def testMultiplyInverseAgainstExplicit(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = array_ops.zeros((2, 2, 2, 2)) - inputs = array_ops.zeros((2, 2, 2, 2)) - outputs = array_ops.zeros((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB( - lc.LayerCollection(), params=params, padding='SAME') - block.register_additional_tower(inputs, outputs) - grads = outputs**2 - damping = 0. # This test is only valid without damping. 
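- # (At zero damping the Kronecker identity (A kron B)^-1 = A^-1 kron B^-1
- # makes the factored inverse exact; with damping > 0 the block inverts
- # each factor separately, which differs from inverting A kron B + d*I.)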
- block.instantiate_factors(((grads,),), damping) - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8))) - sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) - sess.run(block._input_factor.make_inverse_update_ops()) - sess.run(block._output_factor.make_inverse_update_ops()) - - v_flat = np.arange(16, dtype=np.float32) - vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) - output = block.multiply_inverse(vector) - output_flat = sess.run(utils.tensors_to_column(output)).ravel() - - full = sess.run(block.full_fisher_block()) - explicit = np.dot(np.linalg.inv(full + damping * np.eye(16)), v_flat) - - self.assertAllClose(output_flat, explicit) - - -class FullyConnectedSeriesFBTest(test.TestCase): - - def testFullyConnectedSeriesFBInit(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - inputs = array_ops.constant([1., 2.]) - outputs = array_ops.constant([3., 4.]) - block = fb.FullyConnectedSeriesFB(lc.LayerCollection()) - block.register_additional_tower([inputs], [outputs]) - self.assertAllEqual([[outputs]], block.tensors_to_compute_grads()) - - def testInstantiateFactorsHasBias(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - inputs = array_ops.constant([[1., 2.], [3., 4.]]) - outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedSeriesFB( - lc.LayerCollection(), - has_bias=True) - block.register_additional_tower([inputs], [outputs]) - grads = outputs**2 - block.instantiate_factors((((grads,),),), 0.5) - - def testInstantiateFactorsNoBias(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - inputs = array_ops.constant([[1., 2.], [3., 4.]]) - outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedSeriesFB( - lc.LayerCollection(), - has_bias=False) - block.register_additional_tower([inputs], [outputs]) - grads = outputs**2 - block.instantiate_factors((((grads,),),), 0.5) - - -def as_tensors(tensor_or_tuple): - """Converts a potentially nested tuple of np.array to Tensors.""" - if isinstance(tensor_or_tuple, (tuple, list)): - return tuple(as_tensors(t) for t in tensor_or_tuple) - return ops.convert_to_tensor(tensor_or_tuple) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py deleted file mode 100644 index a396ca3f85..0000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ /dev/null @@ -1,955 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ==============================================================================
-"""Tests for tf.contrib.kfac.fisher_factors."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import numpy.random as npr
-
-from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
-from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.platform import test
-
-
-# We need to set these constants since the numerical values used in the tests
-# were chosen back when these settings were the defaults.
-ff.set_global_constants(init_covariances_at_zero=False,
-                        zero_debias=False,
-                        init_inverses_at_zero=False)
-
-
-def make_damping_func(damping):
-  return fb._package_func(lambda: damping, damping)
-
-
-class FisherFactorTestingDummy(ff.FisherFactor):
-  """Dummy class to test the non-abstract methods on ff.FisherFactor."""
-
-  @property
-  def _var_scope(self):
-    return 'dummy/a_b_c'
-
-  @property
-  def _cov_shape(self):
-    raise NotImplementedError
-
-  @property
-  def _num_sources(self):
-    return 1
-
-  @property
-  def _dtype(self):
-    return dtypes.float32
-
-  def _compute_new_cov(self):
-    raise NotImplementedError
-
-  def instantiate_covariance(self):
-    pass
-
-  def make_inverse_update_ops(self):
-    return []
-
-  def get_cov(self):
-    raise NotImplementedError
-
-  def instantiate_inv_variables(self):
-    raise NotImplementedError
-
-  def _num_towers(self):
-    raise NotImplementedError
-
-  def _get_data_device(self):
-    raise NotImplementedError
-
-  def register_matpower(self, exp, damping_func):
-    raise NotImplementedError
-
-  def register_cholesky(self, damping_func):
-    raise NotImplementedError
-
-  def register_cholesky_inverse(self, damping_func):
-    raise NotImplementedError
-
-  def get_matpower(self, exp, damping_func):
-    raise NotImplementedError
-
-  def get_cholesky(self, damping_func):
-    raise NotImplementedError
-
-  def get_cholesky_inverse(self, damping_func):
-    raise NotImplementedError
-
-  def get_cov_as_linear_operator(self):
-    raise NotImplementedError
-
-
-class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor):
-  """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor.
- """ - - def __init__(self, shape): - self._shape = shape - super(DenseSquareMatrixFactorTestingDummy, self).__init__() - - @property - def _var_scope(self): - return 'dummy/a_b_c' - - @property - def _cov_shape(self): - return self._shape - - @property - def _num_sources(self): - return 1 - - @property - def _dtype(self): - return dtypes.float32 - - def _compute_new_cov(self): - raise NotImplementedError - - def instantiate_covariance(self): - pass - - def _num_towers(self): - raise NotImplementedError - - def _get_data_device(self): - raise NotImplementedError - - -class NumericalUtilsTest(test.TestCase): - - def testComputeCovAgainstNumpy(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - npr.seed(0) - random_seed.set_random_seed(200) - - x = npr.randn(100, 3) - cov = ff.compute_cov(array_ops.constant(x)) - np_cov = np.dot(x.T, x) / x.shape[0] - - self.assertAllClose(sess.run(cov), np_cov) - - def testComputeCovAgainstNumpyWithAlternativeNormalizer(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - npr.seed(0) - random_seed.set_random_seed(200) - - normalizer = 10. - x = npr.randn(100, 3) - cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer) - np_cov = np.dot(x.T, x) / normalizer - - self.assertAllClose(sess.run(cov), np_cov) - - def testAppendHomog(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - npr.seed(0) - - m, n = 3, 4 - a = npr.randn(m, n) - a_homog = ff.append_homog(array_ops.constant(a)) - np_result = np.hstack([a, np.ones((m, 1))]) - - self.assertAllClose(sess.run(a_homog), np_result) - - -class NameStringUtilFunctionTest(test.TestCase): - - def _make_tensor(self): - x = array_ops.placeholder(dtypes.float64, (3, 1)) - w = array_ops.constant(npr.RandomState(0).randn(3, 3)) - y = math_ops.matmul(w, x) - g = gradients_impl.gradients(y, x)[0] - return g - - def testScopeStringFromParamsSingleTensor(self): - with tf_ops.Graph().as_default(): - g = self._make_tensor() - scope_string = ff.scope_string_from_params(g) - self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string) - - def testScopeStringFromParamsMultipleTensors(self): - with tf_ops.Graph().as_default(): - x = array_ops.constant(1,) - y = array_ops.constant(2,) - scope_string = ff.scope_string_from_params((x, y)) - self.assertEqual('Const_Const_1', scope_string) - - def testScopeStringFromParamsMultipleTypes(self): - with tf_ops.Graph().as_default(): - x = array_ops.constant(1,) - y = array_ops.constant(2,) - scope_string = ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, - (x, y)]) - self.assertEqual('1-2-3_foo_True_4_Const__Const_1', scope_string) - - def testScopeStringFromParamsUnsupportedType(self): - with tf_ops.Graph().as_default(): - x = array_ops.constant(1,) - y = array_ops.constant(2,) - unsupported = 1.2 # Floats are not supported. 
- with self.assertRaises(ValueError): - ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, (x, y), - unsupported]) - - def testScopeStringFromName(self): - with tf_ops.Graph().as_default(): - g = self._make_tensor() - scope_string = ff.scope_string_from_name(g) - self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string) - - def testScalarOrTensorToString(self): - with tf_ops.Graph().as_default(): - self.assertEqual(ff.scalar_or_tensor_to_string(5.), repr(5.)) - - g = self._make_tensor() - scope_string = ff.scope_string_from_name(g) - self.assertEqual(ff.scalar_or_tensor_to_string(g), scope_string) - - -class FisherFactorTest(test.TestCase): - - def testMakeInverseUpdateOps(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - factor = FisherFactorTestingDummy() - - self.assertEqual(0, len(factor.make_inverse_update_ops())) - - -class DenseSquareMatrixFactorTest(test.TestCase): - - def testRegisterDampedInverse(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - shape = [2, 2] - factor = DenseSquareMatrixFactorTestingDummy(shape) - factor_var_scope = 'dummy/a_b_c' - - damping_funcs = [make_damping_func(0.1), - make_damping_func(0.1), - make_damping_func(1e-5), - make_damping_func(1e-5)] - for damping_func in damping_funcs: - factor.register_inverse(damping_func) - - factor.instantiate_inv_variables() - - inv = factor.get_inverse(damping_funcs[0]).to_dense() - self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense()) - self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense()) - self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(), - factor.get_inverse(damping_funcs[3]).to_dense()) - factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, - factor_var_scope) - factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) - - self.assertEqual(set([inv, - factor.get_inverse(damping_funcs[2]).to_dense()]), - set(factor_tensors)) - self.assertEqual(shape, inv.get_shape()) - - def testRegisterMatpower(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - shape = [3, 3] - factor = DenseSquareMatrixFactorTestingDummy(shape) - factor_var_scope = 'dummy/a_b_c' - - # TODO(b/74201126): Change to using the same func for both once - # Topohash is in place. 
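- # (Each registration below gets its own matrix-power variable, which the
- # assertions on factor_tensors verify.)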
- damping_func_1 = make_damping_func(0.5) - damping_func_2 = make_damping_func(0.5) - - factor.register_matpower(-0.5, damping_func_1) - factor.register_matpower(2, damping_func_2) - - factor.instantiate_inv_variables() - - factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, - factor_var_scope) - - factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) - - matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense() - matpower2 = factor.get_matpower(2, damping_func_2).to_dense() - - self.assertEqual(set([matpower1, matpower2]), set(factor_tensors)) - - self.assertEqual(shape, matpower1.get_shape()) - self.assertEqual(shape, matpower2.get_shape()) - - def testMakeInverseUpdateOps(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - factor = FisherFactorTestingDummy() - - self.assertEqual(0, len(factor.make_inverse_update_ops())) - - def testMakeInverseUpdateOpsManyInversesEigenDecomp(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - cov = np.array([[1., 2.], [3., 4.]]) - factor = DenseSquareMatrixFactorTestingDummy(cov.shape) - factor._cov = array_ops.constant(cov, dtype=dtypes.float32) - - damping_funcs = [] - for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): - damping_funcs.append(make_damping_func(1./i)) - - for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): - factor.register_inverse(damping_funcs[i]) - - factor.instantiate_inv_variables() - ops = factor.make_inverse_update_ops() - self.assertEqual(1, len(ops)) - - sess.run(tf_variables.global_variables_initializer()) - new_invs = [] - sess.run(ops) - for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): - # The inverse op will assign the damped inverse of cov to the inv var. - new_invs.append( - sess.run(factor.get_inverse(damping_funcs[i]).to_dense())) - - # We want to see that the new invs are all different from each other. - for i in range(len(new_invs)): - for j in range(i + 1, len(new_invs)): - # Just check the first element. 
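- # (Each damping value shifts the eigenvalues differently before they are
- # inverted, so every damped inverse must differ.)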
- self.assertNotEqual(new_invs[i][0][0], new_invs[j][0][0]) - - def testMakeInverseUpdateOpsMatPowerEigenDecomp(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - cov = np.array([[6., 2.], [2., 4.]]) - factor = DenseSquareMatrixFactorTestingDummy(cov.shape) - factor._cov = array_ops.constant(cov, dtype=dtypes.float32) - exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power - damping = 0.5 - damping_func = make_damping_func(damping) - - factor.register_matpower(exp, damping_func) - factor.instantiate_inv_variables() - ops = factor.make_inverse_update_ops() - self.assertEqual(1, len(ops)) - - sess.run(tf_variables.global_variables_initializer()) - sess.run(ops[0]) - matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense()) - matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp) - self.assertAllClose(matpower, matpower_np) - - def testMakeInverseUpdateOpsNoEigenDecomp(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric - factor = DenseSquareMatrixFactorTestingDummy(cov.shape) - factor._cov = array_ops.constant(cov, dtype=dtypes.float32) - - damping_func = make_damping_func(0) - - factor.register_inverse(damping_func) - factor.instantiate_inv_variables() - ops = factor.make_inverse_update_ops() - self.assertEqual(1, len(ops)) - - sess.run(tf_variables.global_variables_initializer()) - # The inverse op will assign the damped inverse of cov to the inv var. - old_inv = sess.run(factor.get_inverse(damping_func).to_dense()) - self.assertAllClose( - sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv) - - sess.run(ops) - new_inv = sess.run(factor.get_inverse(damping_func).to_dense()) - self.assertAllClose(new_inv, np.linalg.inv(cov)) - - -class FullFactorTest(test.TestCase): - - def testFullFactorInit(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') - factor = ff.FullFactor((tensor,), 32) - factor.instantiate_cov_variables() - self.assertEqual([6, 6], factor.get_cov().get_shape().as_list()) - - def testFullFactorInitFloat64(self): - with tf_ops.Graph().as_default(): - dtype = dtypes.float64_ref - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - factor = ff.FullFactor((tensor,), 32) - factor.instantiate_cov_variables() - cov = factor.get_cov() - self.assertEqual(cov.dtype, dtype) - self.assertEqual([6, 6], cov.get_shape().as_list()) - - def testMakeCovarianceUpdateOp(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - tensor = array_ops.constant([1., 2.], name='a/b/c') - factor = ff.FullFactor((tensor,), 2) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[0.75, 0.5], [0.5, 1.5]], new_cov) - - -class NaiveDiagonalFactorTest(test.TestCase): - - def testNaiveDiagonalFactorInit(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') - factor = ff.NaiveDiagonalFactor((tensor,), 32) - factor.instantiate_cov_variables() - self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) - - def testNaiveDiagonalFactorInitFloat64(self): - with tf_ops.Graph().as_default(): - dtype = 
dtypes.float64_ref - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - factor = ff.NaiveDiagonalFactor((tensor,), 32) - factor.instantiate_cov_variables() - cov = factor.get_cov() - self.assertEqual(cov.dtype, dtype) - self.assertEqual([6, 1], cov.get_shape().as_list()) - - def testMakeCovarianceUpdateOp(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - tensor = array_ops.constant([1., 2.], name='a/b/c') - factor = ff.NaiveDiagonalFactor((tensor,), 2) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[0.75], [1.5]], new_cov) - - -class EmbeddingInputKroneckerFactorTest(test.TestCase): - - def testInitialization(self): - with tf_ops.Graph().as_default(): - input_ids = array_ops.constant([[0], [1], [4]]) - vocab_size = 5 - factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) - factor.instantiate_cov_variables() - cov = factor.get_cov() - self.assertEqual(cov.shape.as_list(), [vocab_size]) - - def testCovarianceUpdateOp(self): - with tf_ops.Graph().as_default(): - input_ids = array_ops.constant([[0], [1], [4]]) - vocab_size = 5 - factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) - factor.instantiate_cov_variables() - cov_update_op = factor.make_covariance_update_op(0.0) - - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(cov_update_op) - self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov) - - -class ConvDiagonalFactorTest(test.TestCase): - - def setUp(self): - self.batch_size = 10 - self.height = self.width = 32 - self.in_channels = 3 - self.out_channels = 1 - self.kernel_height = self.kernel_width = 3 - self.strides = [1, 2, 2, 1] - self.data_format = 'NHWC' - self.padding = 'SAME' - self.kernel_shape = [ - self.kernel_height, self.kernel_width, self.in_channels, - self.out_channels - ] - - def testInit(self): - with tf_ops.Graph().as_default(): - inputs = random_ops.random_uniform( - [self.batch_size, self.height, self.width, self.in_channels]) - outputs_grads = [ - random_ops.random_uniform([ - self.batch_size, self.height // self.strides[1], - self.width // self.strides[2], self.out_channels - ]) for _ in range(3) - ] - - factor = ff.ConvDiagonalFactor( - (inputs,), - (outputs_grads,), - self.kernel_shape, - self.strides, - self.padding, - data_format=self.data_format) - factor.instantiate_cov_variables() - - # Ensure covariance matrix's shape makes sense. - self.assertEqual([ - self.kernel_height * self.kernel_width * self.in_channels, - self.out_channels - ], - factor.get_cov().shape.as_list()) - - def testMakeCovarianceUpdateOp(self): - with tf_ops.Graph().as_default(): - # Construct all arguments such that convolution kernel is applied in - # exactly one spatial location. - inputs = np.random.randn( - 1, # batch_size - self.kernel_height, - self.kernel_width, - self.in_channels) # in_channels - outputs_grad = np.random.randn( - 1, # batch_size - 1, # output_height - 1, # output_width - self.out_channels) - - factor = ff.ConvDiagonalFactor( - (constant_op.constant(inputs),), - ((constant_op.constant(outputs_grad),),), - self.kernel_shape, - strides=[1, 1, 1, 1], - padding='VALID') - factor.instantiate_cov_variables() - - # Completely forget initial value on first update. 
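- # (make_covariance_update_op(decay) applies a moving average, roughly
- # cov.assign(decay * cov + (1 - decay) * batch_cov), consistent with the
- # 0.5-decay expectations elsewhere in this file; decay 0.0 therefore
- # overwrites the stored value with the batch estimate.)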
- cov_update_op = factor.make_covariance_update_op(0.0) - - # Ensure new covariance value is same as outer-product of inputs/outputs - # vectorized, squared. - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - cov = sess.run(cov_update_op) - expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2 - self.assertAllClose(expected_cov, cov) - - def testHasBias(self): - with tf_ops.Graph().as_default(): - inputs = random_ops.random_uniform( - [self.batch_size, self.height, self.width, self.in_channels]) - outputs_grads = [ - random_ops.random_uniform([ - self.batch_size, self.height // self.strides[1], - self.width // self.strides[2], self.out_channels - ]) for _ in range(3) - ] - - factor = ff.ConvDiagonalFactor( - (inputs,), - (outputs_grads,), - self.kernel_shape, - self.strides, - self.padding, - data_format=self.data_format, - has_bias=True) - factor.instantiate_cov_variables() - - # Ensure shape accounts for bias. - self.assertEqual([ - self.kernel_height * self.kernel_width * self.in_channels + 1, - self.out_channels - ], - factor.get_cov().shape.as_list()) - - # Ensure update op doesn't crash. - cov_update_op = factor.make_covariance_update_op(0.0) - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - sess.run(cov_update_op) - - -class FullyConnectedKroneckerFactorTest(test.TestCase): - - def _testFullyConnectedKroneckerFactorInit(self, - has_bias, - final_shape, - dtype=dtypes.float32_ref): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias) - factor.instantiate_cov_variables() - cov = factor.get_cov() - self.assertEqual(cov.dtype, dtype) - self.assertEqual(final_shape, cov.get_shape().as_list()) - - def testFullyConnectedKroneckerFactorInitNoBias(self): - for dtype in (dtypes.float32_ref, dtypes.float64_ref): - self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype) - - def testFullyConnectedKroneckerFactorInitWithBias(self): - for dtype in (dtypes.float32_ref, dtypes.float64_ref): - self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype) - - def testMakeCovarianceUpdateOpWithBias(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov) - - def testMakeCovarianceUpdateOpNoBias(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor(((tensor,),)) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) - - -class ConvFactorTestCase(test.TestCase): - - def assertMatrixRank(self, rank, matrix, atol=1e-5): - assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.' 
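- # Estimate the rank as the number of eigenvalues above the tolerance.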
- eigvals = np.linalg.eigvals(matrix) - nnz_eigvals = np.sum(eigvals > atol) - self.assertEqual( - rank, - nnz_eigvals, - msg=('Found %d of %d expected non-zero eigenvalues: %s.' % - (nnz_eigvals, rank, eigvals))) - - -class ConvInputKroneckerFactorTest(ConvFactorTestCase): - - def test3DConvolution(self): - with tf_ops.Graph().as_default(): - batch_size = 1 - width = 3 - in_channels = 3**3 - out_channels = 4 - - factor = ff.ConvInputKroneckerFactor( - inputs=(random_ops.random_uniform( - (batch_size, width, width, width, in_channels), seed=0),), - filter_shape=(width, width, width, in_channels, out_channels), - padding='SAME', - strides=(2, 2, 2), - extract_patches_fn='extract_convolution_patches', - has_bias=False) - factor.instantiate_cov_variables() - - # Ensure shape of covariance matches input size of filter. - input_size = in_channels * (width**3) - self.assertEqual([input_size, input_size], - factor.get_cov().shape.as_list()) - - # Ensure cov_update_op doesn't crash. - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov()) - - # Cov should be rank-8, as the filter will be applied at each corner of - # the 4-D cube. - self.assertMatrixRank(8, cov) - - def testPointwiseConv2d(self): - with tf_ops.Graph().as_default(): - batch_size = 1 - width = 3 - in_channels = 3**2 - out_channels = 4 - - factor = ff.ConvInputKroneckerFactor( - inputs=(random_ops.random_uniform( - (batch_size, width, width, in_channels), seed=0),), - filter_shape=(1, 1, in_channels, out_channels), - padding='SAME', - strides=(1, 1, 1, 1), - extract_patches_fn='extract_pointwise_conv2d_patches', - has_bias=False) - factor.instantiate_cov_variables() - - # Ensure shape of covariance matches input size of filter. - self.assertEqual([in_channels, in_channels], - factor.get_cov().shape.as_list()) - - # Ensure cov_update_op doesn't crash. - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov()) - - # Cov should be rank-9, as the filter will be applied at each location. - self.assertMatrixRank(9, cov) - - def testStrides(self): - with tf_ops.Graph().as_default(): - batch_size = 1 - width = 3 - in_channels = 3**2 - out_channels = 4 - - factor = ff.ConvInputKroneckerFactor( - inputs=(random_ops.random_uniform( - (batch_size, width, width, in_channels), seed=0),), - filter_shape=(1, 1, in_channels, out_channels), - padding='SAME', - strides=(1, 2, 1, 1), - extract_patches_fn='extract_image_patches', - has_bias=False) - factor.instantiate_cov_variables() - - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov()) - - # Cov should be the sum of 3 * 2 = 6 outer products. 
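- # (Stride 2 along the height of a width-3 input gives 2 patch rows,
- # stride 1 along the width gives 3 columns, and each of the 6 patches
- # contributes a rank-1 outer product.)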
- self.assertMatrixRank(6, cov) - - def testDilationRate(self): - with tf_ops.Graph().as_default(): - batch_size = 1 - width = 3 - in_channels = 2 - out_channels = 4 - - factor = ff.ConvInputKroneckerFactor( - inputs=(random_ops.random_uniform( - (batch_size, width, width, in_channels), seed=0),), - filter_shape=(3, 3, in_channels, out_channels), - padding='SAME', - extract_patches_fn='extract_image_patches', - strides=(1, 1, 1, 1), - dilation_rate=(1, width, width, 1), - has_bias=False) - factor.instantiate_cov_variables() - - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov()) - - # Cov should be rank = in_channels, as only the center of the filter - # receives non-zero input for each input channel. - self.assertMatrixRank(in_channels, cov) - - def testConvInputKroneckerFactorInitNoBias(self): - with tf_ops.Graph().as_default(): - tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') - factor = ff.ConvInputKroneckerFactor( - inputs=(tensor,), - filter_shape=(1, 2, 3, 4), - padding='SAME', - has_bias=False) - factor.instantiate_cov_variables() - self.assertEqual([1 * 2 * 3, 1 * 2 * 3], - factor.get_cov().get_shape().as_list()) - - def testConvInputKroneckerFactorInit(self): - with tf_ops.Graph().as_default(): - tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') - factor = ff.ConvInputKroneckerFactor( - (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) - factor.instantiate_cov_variables() - self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], - factor.get_cov().get_shape().as_list()) - - def testConvInputKroneckerFactorInitFloat64(self): - with tf_ops.Graph().as_default(): - dtype = dtypes.float64_ref - tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64) - factor = ff.ConvInputKroneckerFactor( - (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) - factor.instantiate_cov_variables() - cov = factor.get_cov() - self.assertEqual(cov.dtype, dtype) - self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], - cov.get_shape().as_list()) - - def testMakeCovarianceUpdateOpWithBias(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - input_shape = (2, 1, 1, 1) - tensor = array_ops.constant( - np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( - np.float32)) - factor = ff.ConvInputKroneckerFactor( - (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(0.)) - self.assertAllClose( - [ - [(1. + 4.) / 2., (1. + 2.) / 2.], # - [(1. + 2.) / 2., (1. + 1.) / 2.] - ], # - new_cov) - - def testMakeCovarianceUpdateOpNoBias(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - input_shape = (2, 1, 1, 1) - tensor = array_ops.constant( - np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( - np.float32)) - factor = ff.ConvInputKroneckerFactor( - (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME') - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(0.)) - self.assertAllClose([[(1. + 4.) 
/ 2.]], new_cov) - - def testSubSample(self): - with tf_ops.Graph().as_default(): - patches_1 = array_ops.constant(1, shape=(10, 2)) - patches_2 = array_ops.constant(1, shape=(10, 8)) - patches_3 = array_ops.constant(1, shape=(3, 3)) - patches_1_sub = ff._subsample_for_cov_computation(patches_1) - patches_2_sub = ff._subsample_for_cov_computation(patches_2) - patches_3_sub = ff._subsample_for_cov_computation(patches_3) - patches_1_sub_batch_size = patches_1_sub.shape.as_list()[0] - patches_2_sub_batch_size = patches_2_sub.shape.as_list()[0] - patches_3_sub_batch_size = patches_3_sub.shape.as_list()[0] - self.assertEqual(2, patches_1_sub_batch_size) - self.assertEqual(8, patches_2_sub_batch_size) - self.assertEqual(3, patches_3_sub_batch_size) - - -class ConvOutputKroneckerFactorTest(ConvFactorTestCase): - - def test3DConvolution(self): - with tf_ops.Graph().as_default(): - batch_size = 1 - width = 3 - out_channels = width**3 - - factor = ff.ConvOutputKroneckerFactor(outputs_grads=([ - random_ops.random_uniform( - (batch_size, width, width, width, out_channels), seed=0) - ],)) - factor.instantiate_cov_variables() - - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov()) - - # Cov should be rank 3^3, as each spatial position donates a rank-1 - # update. - self.assertMatrixRank(width**3, cov) - - def testConvOutputKroneckerFactorInit(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c') - factor = ff.ConvOutputKroneckerFactor(((tensor,),)) - factor.instantiate_cov_variables() - self.assertEqual([5, 5], factor.get_cov().get_shape().as_list()) - - def testConvOutputKroneckerFactorInitFloat64(self): - with tf_ops.Graph().as_default(): - dtype = dtypes.float64_ref - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c') - factor = ff.ConvOutputKroneckerFactor(((tensor,),)) - factor.instantiate_cov_variables() - cov = factor.get_cov() - self.assertEqual(cov.dtype, dtype) - self.assertEqual([5, 5], cov.get_shape().as_list()) - - def testMakeCovarianceUpdateOp(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32) - factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),)) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov) - - -class FullyConnectedMultiKFTest(test.TestCase): - - def testFullyConnectedMultiKFInit(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') - factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) - factor.instantiate_cov_variables() - self.assertEqual([3, 3], factor.get_cov().get_shape().as_list()) - - def testFullyConnectedMultiKFInitFloat64(self): - with tf_ops.Graph().as_default(): - dtype = dtypes.float64_ref - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) - factor.instantiate_cov_variables() - cov = factor.get_cov() - self.assertEqual(cov.dtype, dtype) - self.assertEqual([3, 3], cov.get_shape().as_list()) - - def 
testMakeCovarianceUpdateOpWithBias(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov) - - def testMakeCovarianceUpdateOpNoBias(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedMultiKF(((tensor,),)) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py deleted file mode 100644 index 586fcd4c3c..0000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py +++ /dev/null @@ -1,597 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for tf.contrib.kfac.layer_collection.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.kfac.python.ops import fisher_blocks -from tensorflow.contrib.kfac.python.ops import fisher_factors -from tensorflow.contrib.kfac.python.ops import layer_collection -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import test - - -class MockFisherBlock(object): - """A fake FisherBlock.""" - - num_registered_towers = 2 - - def __init__(self, name='MockFisherBlock'): - self.name = name - - def __eq__(self, other): - return isinstance(other, MockFisherBlock) and other.name == self.name - - def __hash__(self): - return hash(self.name) - - -class LayerParametersDictTest(test.TestCase): - - def testSetItem(self): - """Ensure insertion, contains, retrieval works for supported key types.""" - with ops.Graph().as_default(): - lp_dict = layer_collection.LayerParametersDict() - - x = array_ops.constant(0) - y0 = array_ops.constant(0) - y1 = array_ops.constant(0) - z0 = array_ops.constant(0) - z1 = array_ops.constant(0) - keys = [x, (y0, y1), [z0, z1]] - for key in keys: - lp_dict[key] = key - - for key in keys: - self.assertTrue(key in lp_dict) - self.assertEqual(lp_dict[key], key) - - def testSetItemOverlap(self): - """Ensure insertion fails if key overlaps with existing key.""" - with ops.Graph().as_default(): - lp_dict = layer_collection.LayerParametersDict() - - x = array_ops.constant(0) - y = array_ops.constant(0) - lp_dict[x] = 'value' - - with self.assertRaises(ValueError): - lp_dict[(x, y)] = 'value' - - # Ensure 'y' wasn't inserted. 
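- # (A failed insertion must be atomic: 'x' stays registered and 'y' must
- # not have been added as a side effect.)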
- self.assertTrue(x in lp_dict) - self.assertFalse(y in lp_dict) - - -class LayerCollectionTest(test.TestCase): - - def testLayerCollectionInit(self): - lc = layer_collection.LayerCollection() - self.assertEqual(0, len(lc.get_blocks())) - self.assertEqual(0, len(lc.get_factors())) - self.assertFalse(lc.losses) - - def testRegisterBlocks(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - lc = layer_collection.LayerCollection() - lc.register_fully_connected( - array_ops.constant(1), array_ops.constant(2), array_ops.constant(3)) - lc.register_fully_connected( - array_ops.constant(1), - array_ops.constant(2), - array_ops.constant(3), - approx=layer_collection.APPROX_DIAGONAL_NAME) - lc.register_conv2d( - params=array_ops.ones((2, 3, 4, 5)), - strides=[1, 1, 1, 1], - padding='SAME', - inputs=array_ops.ones((1, 2, 3, 4)), - outputs=array_ops.ones((1, 1, 1, 5))) - lc.register_conv2d( - params=array_ops.ones((2, 3, 4, 5)), - strides=[1, 1, 1, 1], - padding='SAME', - inputs=array_ops.ones((1, 2, 3, 4)), - outputs=array_ops.ones((1, 1, 1, 5)), - approx=layer_collection.APPROX_DIAGONAL_NAME) - lc.register_separable_conv2d( - depthwise_params=array_ops.ones((3, 3, 1, 2)), - pointwise_params=array_ops.ones((1, 1, 2, 4)), - inputs=array_ops.ones((32, 5, 5, 1)), - depthwise_outputs=array_ops.ones((32, 5, 5, 2)), - pointwise_outputs=array_ops.ones((32, 5, 5, 4)), - strides=[1, 1, 1, 1], - padding='SAME') - lc.register_convolution( - params=array_ops.ones((3, 3, 1, 8)), - inputs=array_ops.ones((32, 5, 5, 1)), - outputs=array_ops.ones((32, 5, 5, 8)), - padding='SAME') - lc.register_generic( - array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME) - lc.register_generic( - array_ops.constant(6), - 16, - approx=layer_collection.APPROX_DIAGONAL_NAME) - lc.register_fully_connected_multi( - array_ops.constant(1), - (array_ops.constant(2), array_ops.constant(3)), - (array_ops.constant(4), array_ops.constant(5))) - lc.register_conv2d_multi( - params=array_ops.ones((2, 3, 4, 5)), - strides=[1, 1, 1, 1], - padding='SAME', - inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))), - outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10)))) - lc.register_embedding_multi( - array_ops.constant((1,)), - (array_ops.constant(2), array_ops.constant(3)), - (array_ops.constant(4), array_ops.constant(5))) - - self.assertEqual(12, len(lc.get_blocks())) - - def testRegisterBlocksMultipleRegistrations(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - lc = layer_collection.LayerCollection() - key = array_ops.constant(1) - lc.register_fully_connected(key, array_ops.constant(2), - array_ops.constant(3)) - with self.assertRaises(ValueError) as cm: - lc.register_generic(key, 16) - self.assertIn('already in LayerCollection', str(cm.exception)) - - def testRegisterSingleParamNotRegistered(self): - x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = { - variable_scope.get_variable('y', initializer=array_ops.constant(1,)): - '1' - } - lc.register_block(x, 'foo') - - def testShouldRegisterSingleParamRegistered(self): - x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = {x: '1'} - with self.assertRaises(ValueError) as cm: - lc.register_block(x, 'foo') - self.assertIn('already in LayerCollection', str(cm.exception)) - - def testRegisterSingleParamRegisteredInTuple(self): - x = 
variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = {(x, y): '1'} - with self.assertRaises(ValueError) as cm: - lc.register_block(x, 'foo') - self.assertIn('was already registered', str(cm.exception)) - - def testRegisterTupleParamNotRegistered(self): - x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = { - variable_scope.get_variable('z', initializer=array_ops.constant(1,)): - '1' - } - - lc.register_block((x, y), 'foo') - self.assertEqual(set(['1', 'foo']), set(lc.get_blocks())) - - def testRegisterTupleParamRegistered(self): - x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = {(x, y): '1'} - - with self.assertRaises(ValueError) as cm: - lc.register_block((x, y), 'foo') - self.assertIn('already in LayerCollection', str(cm.exception)) - - def testRegisterTupleParamRegisteredInSuperset(self): - x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) - z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = {(x, y, z): '1'} - - with self.assertRaises(ValueError) as cm: - lc.register_block((x, y), 'foo') - self.assertIn('was already registered', str(cm.exception)) - - def testRegisterTupleParamSomeRegistered(self): - x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) - z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')} - - with self.assertRaises(ValueError) as cm: - lc.register_block((x, y), MockFisherBlock('foo')) - self.assertIn('was already registered', str(cm.exception)) - - def testRegisterTupleVarSomeRegisteredInOtherTuples(self): - x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) - z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) - w = variable_scope.get_variable('w', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = {(x, z): '1', (z, w): '2'} - - with self.assertRaises(ValueError) as cm: - lc.register_block((x, y), 'foo') - self.assertIn('was already registered', str(cm.exception)) - - def testRegisterCategoricalPredictiveDistribution(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - logits = linalg_ops.eye(2) - - lc = layer_collection.LayerCollection() - lc.register_categorical_predictive_distribution(logits, seed=200) - single_loss = sess.run(lc.total_sampled_loss()) - - lc2 = layer_collection.LayerCollection() - lc2.register_categorical_predictive_distribution(logits, seed=200) - lc2.register_categorical_predictive_distribution(logits, seed=200) - double_loss = sess.run(lc2.total_sampled_loss()) - self.assertAlmostEqual(2 * single_loss, double_loss) - - def testLossFunctionByName(self): - 
"""Ensure loss functions can be identified by name.""" - with ops.Graph().as_default(): - logits = linalg_ops.eye(2) - lc = layer_collection.LayerCollection() - - # Create a new loss function by name. - lc.register_categorical_predictive_distribution(logits, name='loss1') - self.assertEqual(1, len(lc.towers_by_loss)) - - # Add logits to same loss function. - lc.register_categorical_predictive_distribution( - logits, name='loss1', reuse=True) - self.assertEqual(1, len(lc.towers_by_loss)) - - # Add another new loss function. - lc.register_categorical_predictive_distribution(logits, name='loss2') - self.assertEqual(2, len(lc.towers_by_loss)) - - def testLossFunctionWithoutName(self): - """Ensure loss functions get unique names if 'name' not specified.""" - with ops.Graph().as_default(): - logits = linalg_ops.eye(2) - lc = layer_collection.LayerCollection() - - # Create a new loss function with default names. - lc.register_categorical_predictive_distribution(logits) - lc.register_categorical_predictive_distribution(logits) - self.assertEqual(2, len(lc.losses)) - - def testCategoricalPredictiveDistributionMultipleMinibatches(self): - """Ensure multiple minibatches are registered.""" - with ops.Graph().as_default(): - batch_size = 3 - output_size = 2 - logits = array_ops.zeros([batch_size, output_size]) - targets = array_ops.ones([batch_size], dtype=dtypes.int32) - lc = layer_collection.LayerCollection() - - # Create a new loss function. - lc.register_categorical_predictive_distribution( - logits, targets=targets, name='loss1') - - # Can add when reuse=True - lc.register_categorical_predictive_distribution( - logits, targets=targets, name='loss1', reuse=True) - - # Can add when reuse=VARIABLE_SCOPE and reuse=True there. - with variable_scope.variable_scope( - variable_scope.get_variable_scope(), reuse=True): - lc.register_categorical_predictive_distribution( - logits, - targets=targets, - name='loss1', - reuse=layer_collection.VARIABLE_SCOPE) - - # Can't add when reuse=False - with self.assertRaises(KeyError): - lc.register_categorical_predictive_distribution( - logits, targets=targets, name='loss1', reuse=False) - - # Can't add when reuse=VARIABLE_SCOPE and reuse=False there. - with self.assertRaises(KeyError): - lc.register_categorical_predictive_distribution( - logits, - targets=targets, - name='loss1', - reuse=layer_collection.VARIABLE_SCOPE) - - self.assertEqual(len(lc.towers_by_loss), 1) - # Three successful registrations. 
- self.assertEqual(len(lc.towers_by_loss[0]), 3) - - def testRegisterCategoricalPredictiveDistributionBatchSize1(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - logits = random_ops.random_normal((1, 2)) - lc = layer_collection.LayerCollection() - - lc.register_categorical_predictive_distribution(logits, seed=200) - - def testRegisterCategoricalPredictiveDistributionSpecifiedTargets(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - logits = array_ops.constant([[1., 2.], [3., 4.]], dtype=dtypes.float32) - lc = layer_collection.LayerCollection() - targets = array_ops.constant([0, 1], dtype=dtypes.int32) - - lc.register_categorical_predictive_distribution(logits, targets=targets) - single_loss = sess.run(lc.total_loss()) - self.assertAlmostEqual(1.6265233, single_loss) - - def testRegisterNormalPredictiveDistribution(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - predictions = array_ops.constant( - [[1., 2.], [3., 4]], dtype=dtypes.float32) - - lc = layer_collection.LayerCollection() - lc.register_normal_predictive_distribution(predictions, 1., seed=200) - single_loss = sess.run(lc.total_sampled_loss()) - - lc2 = layer_collection.LayerCollection() - lc2.register_normal_predictive_distribution(predictions, 1., seed=200) - lc2.register_normal_predictive_distribution(predictions, 1., seed=200) - double_loss = sess.run(lc2.total_sampled_loss()) - - self.assertAlmostEqual(2 * single_loss, double_loss) - - def testRegisterNormalPredictiveDistributionSpecifiedTargets(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - predictions = array_ops.constant( - [[1., 2.], [3., 4.]], dtype=dtypes.float32) - lc = layer_collection.LayerCollection() - targets = array_ops.constant([[3., 1.], [4., 2.]], dtype=dtypes.float32) - - lc.register_normal_predictive_distribution( - predictions, 2.**2, targets=targets) - single_loss = sess.run(lc.total_loss()) - self.assertAlmostEqual(7.6983433, single_loss) - - def ensureLayerReuseWorks(self, register_fn): - """Ensure the 'reuse' keyword argument function as intended. - - Args: - register_fn: function for registering a layer. Arguments are - layer_collection, reuse, and approx. - """ - # Fails on second if reuse=False. - lc = layer_collection.LayerCollection() - register_fn(lc) - with self.assertRaises(ValueError): - register_fn(lc, reuse=False) - - # Succeeds on second if reuse=True. - lc = layer_collection.LayerCollection() - register_fn(lc) - register_fn(lc, reuse=True) - - # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse. - lc = layer_collection.LayerCollection() - register_fn(lc) - with self.assertRaises(ValueError): - register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE) - - # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse. - lc = layer_collection.LayerCollection() - register_fn(lc) - with variable_scope.variable_scope( - variable_scope.get_variable_scope(), reuse=True): - register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE) - - # Fails if block type changes. - lc = layer_collection.LayerCollection() - register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME) - with self.assertRaises(ValueError): - register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True) - - # Fails if reuse requested but no FisherBlock exists. 
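- # (A brand-new LayerCollection has no FisherBlock registered for these
- # params, so requesting reuse raises KeyError.)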
- lc = layer_collection.LayerCollection() - with self.assertRaises(KeyError): - register_fn(lc, reuse=True) - - def testRegisterFullyConnectedReuse(self): - """Ensure the 'reuse' works with register_fully_connected.""" - with ops.Graph().as_default(): - inputs = array_ops.ones([2, 10]) - outputs = array_ops.zeros([2, 5]) - params = ( - variable_scope.get_variable('w', [10, 5]), # - variable_scope.get_variable('b', [5])) - - def register_fn(lc, **kwargs): - lc.register_fully_connected( - params=params, inputs=inputs, outputs=outputs, **kwargs) - - self.ensureLayerReuseWorks(register_fn) - - def testRegisterConv2dReuse(self): - """Ensure the 'reuse' works with register_conv2d.""" - with ops.Graph().as_default(): - inputs = array_ops.ones([2, 5, 5, 10]) - outputs = array_ops.zeros([2, 5, 5, 3]) - params = ( - variable_scope.get_variable('w', [1, 1, 10, 3]), # - variable_scope.get_variable('b', [3])) - - def register_fn(lc, **kwargs): - lc.register_conv2d( - params=params, - strides=[1, 1, 1, 1], - padding='SAME', - inputs=inputs, - outputs=outputs, - **kwargs) - - self.ensureLayerReuseWorks(register_fn) - - def testReuseWithInvalidRegistration(self): - """Invalid registrations shouldn't overwrite existing blocks.""" - with ops.Graph().as_default(): - inputs = array_ops.ones([2, 5, 5, 10]) - outputs = array_ops.zeros([2, 5, 5, 3]) - w = variable_scope.get_variable('w', [1, 1, 10, 3]) - b = variable_scope.get_variable('b', [3]) - lc = layer_collection.LayerCollection() - lc.register_fully_connected(w, inputs, outputs) - self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) - with self.assertRaises(KeyError): - lc.register_fully_connected((w, b), inputs, outputs, reuse=True) - self.assertNotIn((w, b), lc.fisher_blocks) - self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) - lc.register_fully_connected(w, inputs, outputs, reuse=True) - self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2) - - def testMakeOrGetFactor(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - lc = layer_collection.LayerCollection() - key = array_ops.constant(1) - lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) - lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) - lc.make_or_get_factor(fisher_factors.FullFactor, - ((array_ops.constant(2),), 16)) - - self.assertEqual(2, len(lc.get_factors())) - variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertTrue( - all([var.name.startswith('LayerCollection') for var in variables])) - - def testMakeOrGetFactorCustomScope(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - scope = 'Foo' - lc = layer_collection.LayerCollection(name=scope) - key = array_ops.constant(1) - lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) - lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) - lc.make_or_get_factor(fisher_factors.FullFactor, - ((array_ops.constant(2),), 16)) - - self.assertEqual(2, len(lc.get_factors())) - variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertTrue(all([var.name.startswith(scope) for var in variables])) - - def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self): - x = variable_scope.get_variable('x', shape=()) - y = variable_scope.get_variable('y', shape=()) - z = variable_scope.get_variable('z', shape=()) - lc = layer_collection.LayerCollection() - lc.define_linked_parameters((x, y)) - - with self.assertRaises(ValueError): - lc.define_linked_parameters((x, z)) - - def 
testIdentifySubsetPreviouslyRegisteredTensor(self): - x = variable_scope.get_variable('x', shape=()) - y = variable_scope.get_variable('y', shape=()) - lc = layer_collection.LayerCollection() - lc.define_linked_parameters((x, y)) - - with self.assertRaises(ValueError): - lc.define_linked_parameters(x) - - def testSpecifyApproximation(self): - w_0 = variable_scope.get_variable('w_0', [10, 10]) - w_1 = variable_scope.get_variable('w_1', [10, 10]) - - b_0 = variable_scope.get_variable('b_0', [10]) - b_1 = variable_scope.get_variable('b_1', [10]) - - x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) - x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) - - pre_bias_0 = math_ops.matmul(x_0, w_0) - pre_bias_1 = math_ops.matmul(x_1, w_1) - - # Build the fully connected layers in the graph. - pre_bias_0 + b_0 # pylint: disable=pointless-statement - pre_bias_1 + b_1 # pylint: disable=pointless-statement - - lc = layer_collection.LayerCollection() - lc.define_linked_parameters( - w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME) - lc.define_linked_parameters( - w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME) - lc.define_linked_parameters( - b_0, approximation=layer_collection.APPROX_FULL_NAME) - lc.define_linked_parameters( - b_1, approximation=layer_collection.APPROX_FULL_NAME) - - lc.register_fully_connected(w_0, x_0, pre_bias_0) - lc.register_fully_connected( - w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME) - self.assertIsInstance(lc.fisher_blocks[w_0], - fisher_blocks.FullyConnectedDiagonalFB) - self.assertIsInstance(lc.fisher_blocks[w_1], - fisher_blocks.FullyConnectedKFACBasicFB) - - lc.register_generic(b_0, batch_size=1) - lc.register_generic( - b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME) - self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB) - self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB) - - def testDefaultLayerCollection(self): - with ops.Graph().as_default(): - # Can't get default if there isn't one set. - with self.assertRaises(ValueError): - layer_collection.get_default_layer_collection() - - # Can't set default twice. - lc = layer_collection.LayerCollection() - layer_collection.set_default_layer_collection(lc) - with self.assertRaises(ValueError): - layer_collection.set_default_layer_collection(lc) - - # Same as one set. - self.assertTrue(lc is layer_collection.get_default_layer_collection()) - - # Can set to None. - layer_collection.set_default_layer_collection(None) - with self.assertRaises(ValueError): - layer_collection.get_default_layer_collection() - - # as_default() is the same as setting/clearing. - with lc.as_default(): - self.assertTrue(lc is layer_collection.get_default_layer_collection()) - with self.assertRaises(ValueError): - layer_collection.get_default_layer_collection() - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py deleted file mode 100644 index f424e02360..0000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tf.contrib.kfac.loss_functions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.kfac.python.ops import loss_functions -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import test - - -class InsertSliceInZerosTest(test.TestCase): - - def testBadShape(self): - bad_shaped_ones = array_ops.ones(shape=[1, 3]) # n.b. shape[1] != 1 - with self.assertRaises(ValueError): - loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17) - - def test3d(self): - input_tensor = constant_op.constant([[[1, 2]], [[3, 4]]]) - expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]] - op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0) - with self.cached_session() as sess: - actual_output_array = sess.run(op) - self.assertAllEqual(expected_output_array, actual_output_array) - - -class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): - - def testSample(self): - """Ensure samples can be drawn.""" - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.asarray([ - [0., 0., 0.], # - [1., -1., 0.] - ]).astype(np.float32) - loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( - array_ops.constant(logits)) - sample = loss.sample(42) - sample = sess.run(sample) - self.assertEqual(sample.shape, (2,)) - - def testEvaluateOnTargets(self): - """Ensure log probability can be evaluated correctly.""" - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.asarray([ - [0., 0., 0.], # - [1., -1., 0.] - ]).astype(np.float32) - targets = np.asarray([2, 1]).astype(np.int32) - loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( - array_ops.constant(logits), targets=array_ops.constant(targets)) - neg_log_prob = loss.evaluate() - neg_log_prob = sess.run(neg_log_prob) - - # Calculate explicit log probability of targets. - probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) - log_probs = np.log([ - probs[0, targets[0]], # - probs[1, targets[1]] - ]) - expected_log_prob = np.sum(log_probs) - - self.assertAllClose(neg_log_prob, -expected_log_prob) - - def testEvaluateOnSample(self): - """Ensure the log probability of a drawn sample can be evaluated.""" - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.asarray([ - [0., 0., 0.], # - [1., -1., 0.] - ]).astype(np.float32) - loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( - array_ops.constant(logits)) - neg_log_prob = loss.evaluate_on_sample(42) - - # Simply ensure this doesn't crash. As the output is random, it's - # difficult to say if the output is correct or not...
- neg_log_prob = sess.run(neg_log_prob) - - def testMultiplyFisherSingleVector(self): - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.array([1., 2., 3.]) - loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) - - # The LossFunction.multiply_fisher docstring only says it supports the - # case where the vector is the same shape as the input natural parameters - # (i.e. the logits here), but here we also test leading dimensions. - vector = np.array([1., 2., 3.]) - vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)] - - probs = np.exp(logits - np.logaddexp.reduce(logits)) - fisher = np.diag(probs) - np.outer(probs, probs) - - for vector in vectors: - result = loss.multiply_fisher(vector) - expected_result = np.dot(vector, fisher) - self.assertAllClose(expected_result, sess.run(result)) - - def testMultiplyFisherBatch(self): - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.array([[1., 2., 3.], [4., 6., 8.]]) - loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) - - vector = np.array([[1., 2., 3.], [5., 3., 1.]]) - - na = np.newaxis - probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1, - keepdims=True)) - fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :] - - result = loss.multiply_fisher(vector) - expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :] - self.assertEqual(sess.run(result).shape, logits.shape) - self.assertAllClose(expected_result, sess.run(result)) - - -class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase): - - def testSample(self): - """Ensure samples can be drawn.""" - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.asarray([ - [0., 0., 0.], # - [1., -1., 0.] - ]).astype(np.float32) - loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( - array_ops.constant(logits)) - sample = loss.sample(42) - sample = sess.run(sample) - self.assertEqual(sample.shape, (2, 3)) - - def testEvaluateOnTargets(self): - """Ensure log probability can be evaluated correctly.""" - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.asarray([ - [0., 0., 0.], # - [1., -1., 0.] - ]).astype(np.float32) - targets = np.asarray([2, 1]).astype(np.int32) - loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( - array_ops.constant(logits), targets=array_ops.one_hot(targets, 3)) - neg_log_prob = loss.evaluate() - neg_log_prob = sess.run(neg_log_prob) - - # Calculate explicit log probability of targets. - probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) - log_probs = np.log([ - probs[0, targets[0]], # - probs[1, targets[1]] - ]) - expected_log_prob = np.sum(log_probs) - - self.assertAllClose(neg_log_prob, -expected_log_prob) - - def testEvaluateOnSample(self): - """Ensure the log probability of a drawn sample can be evaluated.""" - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.asarray([ - [0., 0., 0.], # - [1., -1., 0.] - ]).astype(np.float32) - loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( - array_ops.constant(logits)) - neg_log_prob = loss.evaluate_on_sample(42) - - # Simply ensure this doesn't crash. As the output is random, it's - # difficult to say if the output is correct or not...
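The closed form behind the multiply_fisher tests above, F = diag(p) - p p^T for softmax probabilities p, is easy to sanity-check in plain NumPy; this snippet is a standalone illustration rather than part of the deleted module:

```python
import numpy as np

logits = np.array([1., 2., 3.])
p = np.exp(logits - np.logaddexp.reduce(logits))  # softmax probabilities

# Fisher information of a categorical distribution w.r.t. its logits.
fisher = np.diag(p) - np.outer(p, p)

# F is symmetric, positive semi-definite, and annihilates the all-ones
# direction (shifting every logit by a constant changes nothing).
assert np.allclose(fisher, fisher.T)
assert np.all(np.linalg.eigvalsh(fisher) >= -1e-12)
assert np.allclose(fisher @ np.ones(3), 0.)
```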
- neg_log_prob = sess.run(neg_log_prob) - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py b/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py deleted file mode 100644 index 4fae4374e1..0000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tf.contrib.kfac.op_queue.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.kfac.python.ops import op_queue -from tensorflow.python.framework import ops as tf_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import test - - -class OpQueueTest(test.TestCase): - - def testNextOp(self): - """Ensures all ops get selected eventually.""" - with tf_ops.Graph().as_default(): - ops = [ - math_ops.add(1, 2), - math_ops.subtract(1, 2), - math_ops.reduce_mean([1, 2]), - ] - queue = op_queue.OpQueue(ops, seed=0) - - with self.cached_session() as sess: - # Ensure every inv update op gets selected. - selected_ops = set([queue.next_op(sess) for _ in ops]) - self.assertEqual(set(ops), set(selected_ops)) - - # Ensure additional calls don't create any new ops. - selected_ops.add(queue.next_op(sess)) - self.assertEqual(set(ops), set(selected_ops)) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py deleted file mode 100644 index 0b0de12ce6..0000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
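For context, op_queue_test.py above pins down the whole OpQueue contract; in the optimizer it was used to run one expensive inverse-update op per step. A minimal usage sketch (the ops are placeholders, and treating next_op's return value as a fetchable op is an assumption based on the test above):

```python
from tensorflow.contrib.kfac.python.ops import op_queue
from tensorflow.python.client import session
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import math_ops

with tf_ops.Graph().as_default():
  # Stand-ins for expensive inverse-update ops.
  update_ops = [math_ops.add(1, 2), math_ops.subtract(1, 2)]
  queue = op_queue.OpQueue(update_ops, seed=0)

  with session.Session() as sess:
    for _ in range(4):
      op = queue.next_op(sess)  # cycles so every op is eventually selected
      sess.run(op)
```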
-# ============================================================================== -"""Tests for tf.contrib.kfac.optimizer.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.kfac.python.ops import fisher_factors as ff -from tensorflow.contrib.kfac.python.ops import layer_collection as lc -from tensorflow.contrib.kfac.python.ops import optimizer -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables as tf_variables -from tensorflow.python.platform import test - - -# We need to set these constants since the numerical values used in the tests -# were chosen when these used to be the defaults. -ff.set_global_constants(init_covariances_at_zero=False, - zero_debias=False, - init_inverses_at_zero=False) - - -def dummy_layer_collection(): - lcoll = lc.LayerCollection() - dummy = array_ops.constant([1., 2.]) - lcoll.register_categorical_predictive_distribution(logits=dummy) - return lcoll - - -class OptimizerTest(test.TestCase): - - def testOptimizerInitInvalidMomentumRegistration(self): - with self.assertRaises(ValueError): - optimizer.KfacOptimizer( - 0.1, 0.2, 0.3, lc.LayerCollection(), momentum_type='foo') - - def testOptimizerInit(self): - with ops.Graph().as_default(): - layer_collection = lc.LayerCollection() - - inputs = array_ops.ones((2, 1)) * 2 - weights_val = np.ones((1, 1), dtype=np.float32) * 3. - weights = variable_scope.get_variable( - 'w', initializer=array_ops.constant(weights_val)) - bias = variable_scope.get_variable( - 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1)) - output = math_ops.matmul(inputs, weights) + bias - - layer_collection.register_fully_connected((weights, bias), inputs, output) - - logits = math_ops.tanh(output) - targets = array_ops.constant([[0.], [1.]]) - output = math_ops.reduce_mean( - nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets)) - - layer_collection.register_categorical_predictive_distribution(logits) - - optimizer.KfacOptimizer( - 0.1, - 0.2, - 0.3, - layer_collection, - momentum=0.5, - momentum_type='regular') - - def testSquaredFisherNorm(self): - with ops.Graph().as_default(), self.cached_session() as sess: - grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None), - (array_ops.constant([[2., 3.], [4., 5.]]), None)] - pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None), - (array_ops.constant([[7., 8.], [9., 10.]]), None)] - opt = optimizer.KfacOptimizer(0.1, 0.2, 0.3, dummy_layer_collection()) - sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars) - self.assertAlmostEqual(174., sess.run(sq_norm), places=5) - - def testUpdateClipCoeff(self): - with ops.Graph().as_default(), self.cached_session() as sess: - grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None), - (array_ops.constant([[2., 3.], [4., 5.]]), None)] - pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None), - (array_ops.constant([[7., 8.], [9., 10.]]), None)] - lrate = 0.1 - - # Note: without rescaling, the squared Fisher norm of the update - # is 1.74 - - # If the update already satisfies the norm constraint, there should - # be no rescaling. 
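The constant 174 asserted in testSquaredFisherNorm, and the 1.74 mentioned in the comment above, follow from a simple identity: if pgrad = F^-1 grad, the squared Fisher norm grad^T F^-1 grad is just the sum of elementwise products of gradients and preconditioned gradients, and the learning rate scales it quadratically. A NumPy check of the arithmetic (illustrative only):

```python
import numpy as np

# Values from testSquaredFisherNorm / testUpdateClipCoeff above.
grads = [np.array([[1., 2.], [3., 4.]]), np.array([[2., 3.], [4., 5.]])]
pgrads = [np.array([[3., 4.], [5., 6.]]), np.array([[7., 8.], [9., 10.]])]

# grad^T F^-1 grad, summed over all parameter blocks.
sq_norm = sum(np.sum(g * pg) for g, pg in zip(grads, pgrads))
assert sq_norm == 174.0
assert np.isclose(0.1**2 * sq_norm, 1.74)  # squared norm of the lrate-scaled update

# Clip coefficient consistent with the assertions in testUpdateClipCoeff:
# rescale only when the update would violate the norm constraint.
lrate, constraint = 0.1, 0.5
coeff = min(1.0, np.sqrt(constraint / (lrate**2 * sq_norm)))
assert np.isclose((lrate * coeff)**2 * sq_norm, constraint)
```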
- opt = optimizer.KfacOptimizer( - lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=10.) - coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars) - self.assertAlmostEqual(1., sess.run(coeff), places=5) - - # If the update violates the constraint, it should be rescaled to - # be on the constraint boundary. - opt = optimizer.KfacOptimizer( - lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=0.5) - coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars) - sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars) - sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad - self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5) - - def testComputeUpdateStepsRegular(self): - # TODO(olganw): implement this. - pass - - def testComputeUpdateStepsAdam(self): - # TODO(olganw): implement this. - pass - - def testUpdateVelocities(self): - with ops.Graph().as_default(), self.cached_session() as sess: - layers = lc.LayerCollection() - layers.register_categorical_predictive_distribution( - array_ops.constant([1.0])) - opt = optimizer.KfacOptimizer( - 0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular') - x = variable_scope.get_variable('x', initializer=array_ops.ones((2, 2))) - y = variable_scope.get_variable( - 'y', initializer=array_ops.ones((2, 2)) * 2) - vec1 = array_ops.ones((2, 2)) * 3 - vec2 = array_ops.ones((2, 2)) * 4 - - model_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5) - opt_vars = [ - v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - if v not in model_vars - ] - - sess.run(tf_variables.global_variables_initializer()) - old_opt_vars = sess.run(opt_vars) - - # Optimizer vars start out at 0. - for opt_var in old_opt_vars: - self.assertAllEqual(sess.run(array_ops.zeros_like(opt_var)), opt_var) - - sess.run(update_op) - new_opt_vars = sess.run(opt_vars) - # After one update, the velocities are equal to the vectors. - for vec, opt_var in zip([vec1, vec2], new_opt_vars): - self.assertAllEqual(sess.run(vec), opt_var) - - sess.run(update_op) - final_opt_vars = sess.run(opt_vars) - for first, second in zip(new_opt_vars, final_opt_vars): - self.assertFalse(np.equal(first, second).all()) - - def testApplyGradients(self): - with ops.Graph().as_default(), self.cached_session() as sess: - layer_collection = lc.LayerCollection() - - inputs = array_ops.ones((2, 1)) * 2 - weights_val = np.ones((1, 1), dtype=np.float32) * 3. 
- weights = variable_scope.get_variable( - 'w', initializer=array_ops.constant(weights_val)) - bias = variable_scope.get_variable( - 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1)) - output = math_ops.matmul(inputs, weights) + bias - - layer_collection.register_fully_connected((weights, bias), inputs, output) - - logits = math_ops.tanh(output) - targets = array_ops.constant([[0.], [1.]]) - output = math_ops.reduce_mean( - nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets)) - - layer_collection.register_categorical_predictive_distribution(logits) - - opt = optimizer.KfacOptimizer( - 0.1, - 0.2, - 0.3, - layer_collection, - momentum=0.5, - momentum_type='regular') - (cov_update_thunks, - inv_update_thunks) = opt.make_vars_and_create_op_thunks() - cov_update_ops = tuple(thunk() for thunk in cov_update_thunks) - inv_update_ops = tuple(thunk() for thunk in inv_update_thunks) - - grads_and_vars = opt.compute_gradients(output, [weights, bias]) - all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars] - - op = opt.apply_gradients(grads_and_vars) - - sess.run(tf_variables.global_variables_initializer()) - old_vars = sess.run(all_vars) - sess.run(cov_update_ops) - sess.run(inv_update_ops) - sess.run(op) - new_vars = sess.run(all_vars) - - for old_var, new_var in zip(old_vars, new_vars): - self.assertNotEqual(old_var, new_var) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py deleted file mode 100644 index 7df79a3c7f..0000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py +++ /dev/null @@ -1,410 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
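testApplyGradients above is also the clearest record of the deleted optimizer's lifecycle: register layers and a predictive distribution, construct the optimizer, materialize covariance and inverse update ops from thunks, then interleave them with the training op. Condensed into a standalone sketch (the model and hyperparameters are illustrative; the positional arguments 0.1, 0.2, 0.3 follow the constructor order used in the tests):

```python
import numpy as np

from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.contrib.kfac.python.ops import optimizer
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables

with ops.Graph().as_default():
  inputs = array_ops.ones((2, 1)) * 2
  weights = variable_scope.get_variable(
      'w', initializer=array_ops.constant(np.ones((1, 1), dtype=np.float32)))
  bias = variable_scope.get_variable(
      'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
  pre_activations = math_ops.matmul(inputs, weights) + bias
  logits = math_ops.tanh(pre_activations)
  targets = array_ops.constant([[0.], [1.]])
  loss = math_ops.reduce_mean(
      nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))

  # Tell K-FAC about the model's structure and output distribution.
  layers = lc.LayerCollection()
  layers.register_fully_connected((weights, bias), inputs, pre_activations)
  layers.register_categorical_predictive_distribution(logits)

  opt = optimizer.KfacOptimizer(
      0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular')
  cov_thunks, inv_thunks = opt.make_vars_and_create_op_thunks()
  cov_update_ops = tuple(thunk() for thunk in cov_thunks)
  inv_update_ops = tuple(thunk() for thunk in inv_thunks)
  train_op = opt.apply_gradients(opt.compute_gradients(loss, [weights, bias]))

  with session.Session() as sess:
    sess.run(tf_variables.global_variables_initializer())
    sess.run(cov_update_ops)  # refresh covariance estimates
    sess.run(inv_update_ops)  # refresh inverses
    sess.run(train_op)        # apply the preconditioned step
```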
-# ============================================================================== -"""Tests for tf.contrib.kfac.utils.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import numpy.random as npr - -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.contrib.tpu.python.tpu import tpu_function -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -class SequenceDictTest(test.TestCase): - - def testSequenceDictInit(self): - seq_dict = utils.SequenceDict() - self.assertFalse(seq_dict._dict) - - def testSequenceDictInitWithIterable(self): - reg_dict = {'a': 'foo', 'b': 'bar'} - itr = zip(reg_dict.keys(), reg_dict.values()) - seq_dict = utils.SequenceDict(itr) - self.assertEqual(reg_dict, seq_dict._dict) - - def testGetItemSingleKey(self): - seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'}) - self.assertEqual('foo', seq_dict['a']) - - def testGetItemMultipleKeys(self): - seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'}) - self.assertEqual(['foo', 'bar'], seq_dict[('a', 'b')]) - - def testSetItemSingleKey(self): - seq_dict = utils.SequenceDict() - seq_dict['a'] = 'foo' - self.assertEqual([('a', 'foo')], seq_dict.items()) - - def testSetItemMultipleKeys(self): - seq_dict = utils.SequenceDict() - keys = ('a', 'b', 'c') - values = ('foo', 'bar', 'baz') - seq_dict[keys] = values - self.assertItemsEqual(list(zip(keys, values)), seq_dict.items()) - - -class SubGraphTest(test.TestCase): - - def testBasicGraph(self): - a = array_ops.constant([[1., 2.], [3., 4.]]) - b = array_ops.constant([[5., 6.], [7., 8.]]) - c = a + b - d = a * b - sub_graph = utils.SubGraph((c,)) - self.assertTrue(sub_graph.is_member(a)) - self.assertTrue(sub_graph.is_member(b)) - self.assertTrue(sub_graph.is_member(c)) - self.assertFalse(sub_graph.is_member(d)) - - def testRepeatedAdds(self): - a = array_ops.constant([[1., 2.], [3., 4.]]) - b = array_ops.constant([[5., 6.], [7., 8.]]) - c = a + b + a # note that a appears twice in this graph - sub_graph = utils.SubGraph((c,)) - self.assertTrue(sub_graph.is_member(a)) - self.assertTrue(sub_graph.is_member(b)) - self.assertTrue(sub_graph.is_member(c)) - - def testFilterList(self): - a = array_ops.constant([[1., 2.], [3., 4.]]) - b = array_ops.constant([[5., 6.], [7., 8.]]) - c = a + b - d = a * b - sub_graph = utils.SubGraph((c,)) - input_list = [b, d] - filtered_list = sub_graph.filter_list(input_list) - self.assertEqual(filtered_list, [b]) - - def testVariableUses(self): - with ops.Graph().as_default(): - var = variable_scope.get_variable('var', shape=[10, 10]) - resource_var = variable_scope.get_variable( - 'resource_var', shape=[10, 10], use_resource=True) - x = array_ops.zeros([3, 10]) - z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var) - z1 = math_ops.matmul(x, resource_var) - sub_graph = utils.SubGraph((z0, z1)) - self.assertEqual(2, sub_graph.variable_uses(var)) - self.assertEqual(1, sub_graph.variable_uses(resource_var)) - - -class UtilsTest(test.TestCase): - - def _fully_connected_layer_params(self): 
- weights_part = array_ops.constant([[1., 2.], [4., 3.]]) - bias_part = array_ops.constant([1., 2.]) - return (weights_part, bias_part) - - def _conv_layer_params(self): - weights_shape = 2, 2, 3, 4 - biases_shape = weights_shape[-1:] - weights = array_ops.constant(npr.RandomState(0).randn(*weights_shape)) - biases = array_ops.constant(npr.RandomState(1).randn(*biases_shape)) - return (weights, biases) - - def testFullyConnectedLayerParamsTupleToMat2d(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - layer_params = self._fully_connected_layer_params() - output = utils.layer_params_to_mat2d(layer_params) - self.assertListEqual([3, 2], output.get_shape().as_list()) - self.assertAllClose( - sess.run(output), np.array([[1., 2.], [4., 3.], [1., 2.]])) - - def testFullyConnectedLayerParamsTensorToMat2d(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - layer_params = self._fully_connected_layer_params() - output = utils.layer_params_to_mat2d(layer_params[0]) - self.assertListEqual([2, 2], output.get_shape().as_list()) - self.assertAllClose(sess.run(output), np.array([[1., 2.], [4., 3.]])) - - def testConvLayerParamsTupleToMat2d(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - layer_params = self._conv_layer_params() - output = utils.layer_params_to_mat2d(layer_params) - self.assertListEqual([2 * 2 * 3 + 1, 4], output.get_shape().as_list()) - - def testKron(self): - with ops.Graph().as_default(), self.cached_session() as sess: - mat1 = np.array([[1., 2.], [3., 4.]]) - mat2 = np.array([[5., 6.], [7., 8.]]) - mat1_tf = array_ops.constant(mat1) - mat2_tf = array_ops.constant(mat2) - ans_tf = sess.run(utils.kronecker_product(mat1_tf, mat2_tf)) - ans_np = np.kron(mat1, mat2) - self.assertAllClose(ans_tf, ans_np) - - def testMat2dToFullyConnectedLayerParamsTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - vector_template = self._fully_connected_layer_params() - mat2d = array_ops.constant([[5., 4.], [3., 2.], [1., 0.]]) - - output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d)) - - self.assertIsInstance(output, tuple) - self.assertEqual(len(output), 2) - a, b = output - self.assertAllClose(a, np.array([[5., 4.], [3., 2.]])) - self.assertAllClose(b, np.array([1., 0.])) - - def testMat2dToFullyConnectedLayerParamsTensor(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - vector_template = self._fully_connected_layer_params()[0] - mat2d = array_ops.constant([[5., 4.], [3., 2.]]) - - output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d)) - - self.assertAllClose(output, np.array([[5., 4.], [3., 2.]])) - - def testTensorsToColumn(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - - vector = array_ops.constant(np.array([[0., 1.], [2., 3.]])) - output = utils.tensors_to_column(vector) - self.assertListEqual([4, 1], output.get_shape().as_list()) - self.assertAllClose(sess.run(output), np.array([0., 1., 2., 3.])[:, None]) - - vector = self._fully_connected_layer_params() - output = utils.tensors_to_column(vector) - self.assertListEqual([6, 1], output.get_shape().as_list()) - self.assertAllClose( - sess.run(output), np.array([1., 2., 4., 3., 1., 2.])[:, None]) - - vector = list(vector) - vector.append(array_ops.constant([[6.], [7.], [8.], [9.]])) - - output = 
utils.tensors_to_column(vector) - self.assertListEqual([10, 1], output.get_shape().as_list()) - self.assertAllClose( - sess.run(output), - np.array([1., 2., 4., 3., 1., 2., 6., 7., 8., 9.])[:, None]) - - def testColumnToTensors(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - - vector_template = array_ops.constant(np.array([[0., 1.], [2., 3.]])) - colvec = array_ops.constant(np.arange(4.)[:, None]) - output = sess.run(utils.column_to_tensors(vector_template, colvec)) - self.assertAllClose(output, np.array([[0., 1.], [2., 3.]])) - - vector_template = self._fully_connected_layer_params() - colvec = array_ops.constant(np.arange(6.)[:, None]) - output = sess.run(utils.column_to_tensors(vector_template, colvec)) - - self.assertIsInstance(output, tuple) - self.assertEqual(len(output), 2) - a, b = output - self.assertAllClose(a, np.array([[0., 1.], [2., 3.]])) - self.assertAllClose(b, np.array([4., 5.])) - - vector_template = list(vector_template) - vector_template.append(array_ops.constant([[6.], [7.], [8.], [9.]])) - colvec = array_ops.constant(np.arange(10.)[:, None]) - output = sess.run(utils.column_to_tensors(vector_template, colvec)) - self.assertIsInstance(output, tuple) - self.assertEqual(len(output), 3) - a, b, c = output - self.assertAllClose(a, np.array([[0., 1.], [2., 3.]])) - self.assertAllClose(b, np.array([4., 5.])) - self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]])) - - def testPosDefInvCholesky(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - npr.seed(0) - square = lambda x: np.dot(x, x.T) - - size = 3 - x = square(npr.randn(size, size)) - damp = 0.1 - identity = linalg_ops.eye(size, dtype=dtypes.float64) - - tf_inv = utils.posdef_inv_cholesky(array_ops.constant(x), identity, damp) - np_inv = np.linalg.inv(x + damp * np.eye(size)) - self.assertAllClose(sess.run(tf_inv), np_inv) - - def testPosDefInvMatrixInverse(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - npr.seed(0) - square = lambda x: np.dot(x, x.T) - - size = 3 - x = square(npr.randn(size, size)) - damp = 0.1 - identity = linalg_ops.eye(size, dtype=dtypes.float64) - - tf_inv = utils.posdef_inv_matrix_inverse( - array_ops.constant(x), identity, damp) - np_inv = np.linalg.inv(x + damp * np.eye(size)) - self.assertAllClose(sess.run(tf_inv), np_inv) - - def testCrossReplicaMean(self): - """Ensures that cross_replica_mean() executes only when num_shards > 1.""" - with ops.Graph().as_default(): - with tpu_function.tpu_shard_context(4): - tensor = array_ops.zeros([], dtype=dtypes.float32) - mean = utils.cross_replica_mean(tensor) - self.assertNotEqual(mean, tensor) - - with ops.Graph().as_default(): - with tpu_function.tpu_shard_context(1): - tensor = array_ops.zeros([], dtype=dtypes.float32) - mean = utils.cross_replica_mean(tensor) - self.assertEqual(mean, tensor) - - with ops.Graph().as_default(): - with self.assertRaises(ValueError): # Outside of TPU context. 
- tensor = array_ops.zeros([], dtype=dtypes.float32) - mean = utils.cross_replica_mean(tensor) - - def testBatchExecute(self): - """Ensure batch_execute runs in a round-robin fashion.""" - - def increment_var(var): - return lambda: var.assign_add(1) - - with ops.Graph().as_default(), self.cached_session() as sess: - i = variable_scope.get_variable('i', initializer=0) - accumulators = [ - variable_scope.get_variable('var%d' % j, initializer=0) - for j in range(3) - ] - thunks = [increment_var(var) for var in accumulators] - increment_accumulators = utils.batch_execute(i, thunks, 2) - increment_i = i.assign_add(1) - - sess.run(variables.global_variables_initializer()) - - # Ensure one op per thunk. - self.assertEqual(3, len(increment_accumulators)) - - # Ensure round-robin execution. - values = [] - for _ in range(5): - sess.run(increment_accumulators) - sess.run(increment_i) - values.append(sess.run(accumulators)) - self.assertAllClose( - [ - [1, 1, 0], # - [2, 1, 1], # - [2, 2, 2], # - [3, 3, 2], # - [4, 3, 3] - ], - values) - - def testExtractConvolutionPatches(self): - with ops.Graph().as_default(), self.cached_session() as sess: - batch_size = 10 - image_spatial_shape = [9, 10, 11] - in_channels = out_channels = 32 - kernel_spatial_shape = [5, 3, 3] - spatial_strides = [1, 2, 1] - spatial_dilation = [1, 1, 1] - padding = 'SAME' - - images = random_ops.random_uniform( - [batch_size] + image_spatial_shape + [in_channels], seed=0) - kernel_shape = kernel_spatial_shape + [in_channels, out_channels] - kernel = random_ops.random_uniform(kernel_shape, seed=1) - - # Ensure shape matches expectation. - patches = utils.extract_convolution_patches( - images, - kernel_shape, - padding, - strides=spatial_strides, - dilation_rate=spatial_dilation) - result_spatial_shape = ( - patches.shape.as_list()[1:1 + len(image_spatial_shape)]) - self.assertEqual(patches.shape.as_list(), - [batch_size] + result_spatial_shape + - kernel_spatial_shape + [in_channels]) - - # Ensure extract...patches() + matmul() and convolution() implementation - # give the same answer. - outputs = nn_ops.convolution( - images, - kernel, - padding, - strides=spatial_strides, - dilation_rate=spatial_dilation) - - patches_flat = array_ops.reshape( - patches, [-1, np.prod(kernel_spatial_shape) * in_channels]) - kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) - outputs_flat = math_ops.matmul(patches_flat, kernel_flat) - - outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) - self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) - - def testExtractPointwiseConv2dPatches(self): - with ops.Graph().as_default(), self.cached_session() as sess: - batch_size = 10 - image_height = image_width = 8 - in_channels = out_channels = 3 - kernel_height = kernel_width = 1 - strides = [1, 1, 1, 1] - padding = 'VALID' - - images = random_ops.random_uniform( - [batch_size, image_height, image_width, in_channels], seed=0) - kernel_shape = [kernel_height, kernel_width, in_channels, out_channels] - kernel = random_ops.random_uniform(kernel_shape, seed=1) - - # Ensure shape matches expectation. - patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape) - self.assertEqual(patches.shape.as_list(), [ - batch_size, image_height, image_width, kernel_height, kernel_width, - in_channels - ]) - - # Ensure extract...patches() + matmul() and conv2d() implementation - # give the same answer. 
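The schedule asserted in testBatchExecute earlier in this file has a compact closed form: with batch size b and n thunks, step k executes thunks (kb + j) mod n for j < b. A pure-Python rendering that reproduces the expected table (illustrative, and independent of the TF implementation):

```python
def round_robin_counts(num_thunks, batch_size, steps):
  """Replays utils.batch_execute's round-robin schedule on plain counters."""
  counts = [0] * num_thunks
  history = []
  for step in range(steps):
    start = (step * batch_size) % num_thunks
    for j in range(batch_size):
      counts[(start + j) % num_thunks] += 1
    history.append(list(counts))
  return history

# Matches the `values` table asserted in testBatchExecute.
assert round_robin_counts(3, 2, 5) == [
    [1, 1, 0], [2, 1, 1], [2, 2, 2], [3, 3, 2], [4, 3, 3]]
```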
- outputs = nn_ops.conv2d(images, kernel, strides, padding) - - patches_flat = array_ops.reshape( - patches, [-1, kernel_height * kernel_width * in_channels]) - kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) - outputs_flat = math_ops.matmul(patches_flat, kernel_flat) - - outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) - self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD deleted file mode 100644 index 3c01eb65e7..0000000000 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ /dev/null @@ -1,263 +0,0 @@ -package(default_visibility = [ - "//tensorflow/contrib/kfac:__pkg__", - "//tensorflow/contrib/kfac/python/kernel_tests:__pkg__", -]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -py_library( - name = "fisher_blocks", - srcs = ["fisher_blocks.py"], - srcs_version = "PY2AND3", - deps = [ - ":fisher_factors", - ":utils", - "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", - "@six_archive//:six", - ], -) - -py_library( - name = "fisher_blocks_lib", - srcs = ["fisher_blocks_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":fisher_blocks", - "//tensorflow/python:util", - ], -) - -py_library( - name = "fisher_factors", - srcs = ["fisher_factors.py"], - srcs_version = "PY2AND3", - deps = [ - ":linear_operator", - ":utils", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:special_math_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - -py_library( - name = "fisher_factors_lib", - srcs = ["fisher_factors_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":fisher_factors", - "//tensorflow/python:util", - ], -) - -py_library( - name = "linear_operator", - srcs = ["linear_operator.py"], - srcs_version = "PY2AND3", - deps = [ - ":utils", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python/ops/linalg", - "@six_archive//:six", - ], -) - -py_library( - name = "loss_functions", - srcs = ["loss_functions.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/ops/distributions", - "@six_archive//:six", - ], -) - -py_library( - name = "loss_functions_lib", - srcs = ["loss_functions_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":loss_functions", - "//tensorflow/python:util", - ], -) - -py_library( - name = "curvature_matrix_vector_products", - srcs = ["curvature_matrix_vector_products.py"], - srcs_version = "PY2AND3", - deps = [ - ":utils", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:util", - ], -) - -py_library( - name = "curvature_matrix_vector_products_lib", - srcs = ["curvature_matrix_vector_products_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":curvature_matrix_vector_products", - "//tensorflow/python:util", - ], -) - -py_library( - name = "layer_collection", - srcs = ["layer_collection.py"], - srcs_version = "PY2AND3", - deps = 
[ - ":fisher_blocks", - ":loss_functions", - ":utils", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "@six_archive//:six", - ], -) - -py_library( - name = "layer_collection_lib", - srcs = ["layer_collection_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":layer_collection", - "//tensorflow/python:util", - ], -) - -py_library( - name = "kfac_optimizer", - srcs = [ - "optimizer.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":curvature_matrix_vector_products", - ":fisher_estimator", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], -) - -py_library( - name = "kfac_optimizer_lib", - srcs = [ - "optimizer_lib.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":kfac_optimizer", - "//tensorflow/python:util", - ], -) - -py_library( - name = "fisher_estimator", - srcs = [ - "estimator.py", - "placement.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":utils", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:util", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - -py_library( - name = "fisher_estimator_lib", - srcs = [ - "estimator_lib.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":fisher_estimator", - "//tensorflow/python:util", - ], -) - -py_library( - name = "utils", - srcs = ["utils.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/tpu", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//third_party/py/numpy", - ], -) - -py_library( - name = "utils_lib", - srcs = ["utils_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":utils", - "//tensorflow/python:util", - ], -) - -py_library( - name = "op_queue", - srcs = ["op_queue.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/python:framework_ops", - ], -) - -py_library( - name = "op_queue_lib", - srcs = ["op_queue_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":op_queue", - "//tensorflow/python:util", - ], -) diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py deleted file mode 100644 index 21b5cde9b9..0000000000 --- a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Curvature matrix-vector multiplication.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import math_ops -from tensorflow.python.util import nest - - -class CurvatureMatrixVectorProductComputer(object): - """Class for computing matrix-vector products for Fishers, GGNs and Hessians. - - In other words, we compute M*v where M is the matrix, v is the vector, and - * refers to standard matrix/vector multiplication (not element-wise - multiplication). - - The matrices are defined in terms of some differential quantity of the total - loss function with respect to a provided list of tensors ("wrt_tensors"). - For example, the Fisher associated with a log-prob loss w.r.t. the - parameters. - - The 'vecs' argument to each method is a list of tensors that must be the - same size as the corresponding ones from "wrt_tensors". They represent - the vector being multiplied. - - "factors" of the matrix M are defined as matrices B such that B*B^T = M. - Methods that multiply by the factor B take a 'loss_inner_vecs' argument - instead of 'vecs', which must be a list of tensors with shapes given by the - corresponding XXX_inner_shapes property. - - Note that matrix-vector products are not normalized by the batch size, nor - are any damping terms added to the results. These things can be easily - applied externally, if desired. - - See for example: www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf - and https://arxiv.org/abs/1412.1193 for more information about the - generalized Gauss-Newton, Fisher, etc., and how to compute matrix-vector - products. - """ - - def __init__(self, losses, wrt_tensors): - """Create a CurvatureMatrixVectorProductComputer object. - - Args: - losses: A list of LossFunction instances whose sum defines the total loss. - wrt_tensors: A list of Tensors to compute the differential quantities - (defining the matrices) with respect to. See class description for more - info. - """ - self._losses = losses - self._inputs_to_losses = list(loss.inputs for loss in losses) - self._inputs_to_losses_flat = nest.flatten(self._inputs_to_losses) - self._wrt_tensors = wrt_tensors - - @property - def _total_loss(self): - return math_ops.add_n(tuple(loss.evaluate() for loss in self._losses)) - - # Jacobian multiplication functions: - def _multiply_jacobian(self, vecs): - """Multiply vecs by the Jacobian of losses.""" - # We stop gradients at wrt_tensors to produce partial derivatives (which is - # what we want for Jacobians). - jacobian_vecs_flat = utils.fwd_gradients( - self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs, - stop_gradients=self._wrt_tensors) - return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat) - - def _multiply_jacobian_transpose(self, loss_vecs): - """Multiply vecs by the transpose Jacobian of losses.""" - loss_vecs_flat = nest.flatten(loss_vecs) - # We stop gradients at wrt_tensors to produce partial derivatives (which is - # what we want for Jacobians).
- return gradients_impl.gradients( - self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat, - stop_gradients=self._wrt_tensors) - - # Losses Fisher/Hessian multiplication functions: - def _multiply_loss_fisher(self, loss_vecs): - """Multiply loss_vecs by Fisher of total loss.""" - return tuple( - loss.multiply_fisher(loss_vec) - for loss, loss_vec in zip(self._losses, loss_vecs)) - - def _multiply_loss_fisher_factor(self, loss_inner_vecs): - """Multiply loss_inner_vecs by factor of Fisher of total loss.""" - return tuple( - loss.multiply_fisher_factor(loss_vec) - for loss, loss_vec in zip(self._losses, loss_inner_vecs)) - - def _multiply_loss_fisher_factor_transpose(self, loss_vecs): - """Multiply loss_vecs by transpose factor of Fisher of total loss.""" - return tuple( - loss.multiply_fisher_factor_transpose(loss_vec) - for loss, loss_vec in zip(self._losses, loss_vecs)) - - def _multiply_loss_hessian(self, loss_vecs): - """Multiply loss_vecs by Hessian of total loss.""" - return tuple( - loss.multiply_hessian(loss_vec) - for loss, loss_vec in zip(self._losses, loss_vecs)) - - def _multiply_loss_hessian_factor(self, loss_inner_vecs): - """Multiply loss_inner_vecs by factor of Hessian of total loss.""" - return tuple( - loss.multiply_hessian_factor(loss_vec) - for loss, loss_vec in zip(self._losses, loss_inner_vecs)) - - def _multiply_loss_hessian_factor_transpose(self, loss_vecs): - """Multiply loss_vecs by transpose factor of Hessian of total loss.""" - return tuple( - loss.multiply_hessian_factor_transpose(loss_vec) - for loss, loss_vec in zip(self._losses, loss_vecs)) - - # Matrix-vector product functions: - def multiply_fisher(self, vecs): - """Multiply vecs by Fisher of total loss.""" - jacobian_vecs = self._multiply_jacobian(vecs) - loss_fisher_jacobian_vecs = self._multiply_loss_fisher(jacobian_vecs) - return self._multiply_jacobian_transpose(loss_fisher_jacobian_vecs) - - def multiply_fisher_factor_transpose(self, vecs): - """Multiply vecs by transpose of factor of Fisher of total loss.""" - jacobian_vecs = self._multiply_jacobian(vecs) - return self._multiply_loss_fisher_factor_transpose(jacobian_vecs) - - def multiply_fisher_factor(self, loss_inner_vecs): - """Multiply loss_inner_vecs by factor of Fisher of total loss.""" - fisher_factor_transpose_vecs = self._multiply_loss_fisher_factor_transpose( - loss_inner_vecs) - return self._multiply_jacobian_transpose(fisher_factor_transpose_vecs) - - def multiply_hessian(self, vecs): - """Multiply vecs by Hessian of total loss.""" - return gradients_impl.gradients( - gradients_impl.gradients(self._total_loss, self._wrt_tensors), - self._wrt_tensors, - grad_ys=vecs) - - def multiply_generalized_gauss_newton(self, vecs): - """Multiply vecs by generalized Gauss-Newton of total loss.""" - jacobian_vecs = self._multiply_jacobian(vecs) - loss_hessian_jacobian_vecs = self._multiply_loss_hessian(jacobian_vecs) - return self._multiply_jacobian_transpose(loss_hessian_jacobian_vecs) - - def multiply_generalized_gauss_newton_factor_transpose(self, vecs): - """Multiply vecs by transpose of factor of GGN of total loss.""" - jacobian_vecs = self._multiply_jacobian(vecs) - return self._multiply_loss_hessian_factor_transpose(jacobian_vecs) - - def multiply_generalized_gauss_newton_factor(self, loss_inner_vecs): - """Multiply loss_inner_vecs by factor of GGN of total loss.""" - hessian_factor_transpose_vecs = ( - self._multiply_loss_hessian_factor_transpose(loss_inner_vecs)) - return 
self._multiply_jacobian_transpose(hessian_factor_transpose_vecs) - - # Shape properties for multiply_XXX_factor methods: - @property - def fisher_factor_inner_shapes(self): - """Shapes required by multiply_fisher_factor.""" - return tuple(loss.fisher_factor_inner_shape for loss in self._losses) - - @property - def generalized_gauss_newton_factor_inner_shapes(self): - """Shapes required by multiply_generalized_gauss_newton_factor.""" - return tuple(loss.hessian_factor_inner_shape for loss in self._losses) diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py deleted file mode 100644 index 6e8c6404dc..0000000000 --- a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Curvature matrix-vector multiplication.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.curvature_matrix_vector_products import * -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import - -_allowed_symbols = [ - 'CurvatureMatrixVectorProductComputer', -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py deleted file mode 100644 index 323234c403..0000000000 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ /dev/null @@ -1,516 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
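For context, the class deleted above composes all of its products from three primitives: a Fisher-vector product is multiply_jacobian, then multiply_loss_fisher, then multiply_jacobian_transpose, i.e. Fv = J^T F_loss J v. A minimal usage sketch, pairing it with the CategoricalLogitsNegativeLogProbLoss from the deleted loss_functions module (the model tensors are illustrative):

```python
from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products
from tensorflow.contrib.kfac.python.ops import loss_functions
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope

with ops.Graph().as_default():
  x = array_ops.ones([4, 3])
  w = variable_scope.get_variable('w', shape=[3, 2])
  logits = math_ops.matmul(x, w)
  loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)

  computer = (
      curvature_matrix_vector_products.CurvatureMatrixVectorProductComputer(
          [loss], [w]))

  # F v = J^T F_loss J v, assembled from the helpers defined above.
  v = array_ops.ones_like(w)
  fisher_vec = computer.multiply_fisher([v])
  ggn_vec = computer.multiply_generalized_gauss_newton([v])
```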
-# ============================================================================== -"""Defines the high-level Fisher estimator class.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import abc -import numpy as np -import six - -from tensorflow.contrib.kfac.python.ops import placement -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import ops as tf_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import variable_scope -from tensorflow.python.util import nest - - -# The linter is confused. -# pylint: disable=abstract-class-instantiated -def make_fisher_estimator(placement_strategy=None, **kwargs): - """Creates Fisher estimator instances based on the placement strategy. - - For example if the `placement_strategy` is 'round_robin' then - `FisherEstimatorRoundRobin` instance is returned. - - Args: - placement_strategy: `string`, Strategy to be used for placing covariance - variables, covariance ops and inverse ops. Check - `placement.FisherEstimatorRoundRobin` for a concrete example. - **kwargs: Arguments to be passed into `FisherEstimator` class initializer. - - Returns: - An instance of class which inherits from `FisherEstimator` and the mixin - which implements specific placement strategy. See, - `FisherEstimatorRoundRobin` which inherits from `FisherEstimator` and - `RoundRobinPlacementMixin`. - - Raises: - ValueError: If the `placement_strategy` is not equal to 'round_robin'. - """ - if placement_strategy in [None, "round_robin"]: - return FisherEstimatorRoundRobin(**kwargs) - else: - raise ValueError("Unimplemented vars and ops " - "placement strategy : {}".format(placement_strategy)) -# pylint: enable=abstract-class-instantiated - - -@six.add_metaclass(abc.ABCMeta) -class FisherEstimator(object): - """Fisher estimator class supporting various approximations of the Fisher. - - This is an abstract base class which does not implement a strategy for - placing covariance variables, covariance update ops and inverse update ops. - The placement strategies are implemented in `placement.py`. See - `FisherEstimatorRoundRobin` for example of a concrete subclass with - a round-robin placement strategy. - """ - - def __init__(self, - variables, - cov_ema_decay, - damping, - layer_collection, - exps=(-1,), - estimation_mode="gradients", - colocate_gradients_with_ops=True, - name="FisherEstimator", - compute_cholesky=False, - compute_cholesky_inverse=False): - """Create a FisherEstimator object. - - Args: - variables: A `list` of variables or `callable` which returns the variables - for which to estimate the Fisher. This must match the variables - registered in layer_collection (if it is not None). - cov_ema_decay: The decay factor used when calculating the covariance - estimate moving averages. - damping: float. The damping factor used to stabilize training due to - errors in the local approximation with the Fisher information matrix, - and to regularize the update direction by making it closer to the - gradient. (Higher damping means the update looks more like a standard - gradient update - see Tikhonov regularization.) - layer_collection: The layer collection object, which holds the Fisher - blocks, Kronecker factors, and losses associated with the - graph. - exps: List of floats or ints. 
These represent the different matrix - powers of the approximate Fisher that the FisherEstimator will be able - to multiply vectors by. If the user asks for a matrix power other - than one of these (or 1, which is always supported), there will be a - failure. (Default: (-1,)) - estimation_mode: The type of estimator to use for the Fishers. Can be - 'gradients', 'empirical', 'curvature_prop', or 'exact'. - (Default: 'gradients'). 'gradients' is the basic estimation approach - from the original K-FAC paper. 'empirical' computes the 'empirical' - Fisher information matrix (which uses the data's distribution for the - targets, as opposed to the true Fisher which uses the model's - distribution) and requires that each registered loss have specified - targets. 'curvature_prop' is a method which estimates the - Fisher using self-products of random 1/-1 vectors times "half-factors" - of the Fisher, as described here: https://arxiv.org/abs/1206.6464 . - Finally, 'exact' is the obvious generalization of Curvature - Propagation to compute the exact Fisher (modulo any additional - diagonal or Kronecker approximations) by looping over one-hot vectors - for each coordinate of the output instead of using 1/-1 vectors. It - is more expensive to compute than the other three options by a factor - equal to the output dimension, roughly speaking. - colocate_gradients_with_ops: Whether we should request gradients be - colocated with their respective ops. (Default: True) - name: A string. A name given to this estimator, which is added to the - variable scope when constructing variables and ops. - (Default: "FisherEstimator") - compute_cholesky: Bool. Whether or not the FisherEstimator will be - able to multiply vectors by the Cholesky factor. - (Default: False) - compute_cholesky_inverse: Bool. Whether or not the FisherEstimator - will be able to multiply vectors by the Cholesky factor inverse. - (Default: False) - Raises: - ValueError: If no losses have been registered with layer_collection. - """ - self._variables = variables - self._cov_ema_decay = cov_ema_decay - self._damping = damping - self._estimation_mode = estimation_mode - self._layers = layer_collection - self._gradient_fns = { - "gradients": self._get_grads_lists_gradients, - "empirical": self._get_grads_lists_empirical, - "curvature_prop": self._get_grads_lists_curvature_prop, - "exact": self._get_grads_lists_exact - } - self._colocate_gradients_with_ops = colocate_gradients_with_ops - - self._made_vars = False - self._exps = exps - self._compute_cholesky = compute_cholesky - self._compute_cholesky_inverse = compute_cholesky_inverse - - self._name = name - - @property - def variables(self): - if callable(self._variables): - return self._variables() - else: - return self._variables - - @property - def damping(self): - return self._damping - - @property - def blocks(self): - """All registered FisherBlocks.""" - return self._layers.get_blocks() - - @property - def factors(self): - """All registered FisherFactors.""" - return self._layers.get_factors() - - @property - def name(self): - return self._name - - @abc.abstractmethod - def make_vars_and_create_op_thunks(self, scope=None): - """Make vars and create op thunks with a specific placement strategy. - - For each factor, all of that factor's cov variables and their associated - update ops will be placed on a particular device. A new device is chosen - for each factor by cycling through the list of devices in the cov_devices - argument. If cov_devices is None then no explicit device placement occurs.
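Putting the constructor documentation above to work: estimators were obtained through the make_fisher_estimator factory ('round_robin' being the only implemented placement strategy). A sketch with illustrative hyperparameters and a minimal registered model:

```python
from tensorflow.contrib.kfac.python.ops import estimator
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope

with ops.Graph().as_default():
  x = array_ops.ones([4, 3])
  w = variable_scope.get_variable('w', shape=[3, 2])
  logits = math_ops.matmul(x, w)

  layers = lc.LayerCollection()
  layers.register_fully_connected(w, x, logits)
  layers.register_categorical_predictive_distribution(logits)

  est = estimator.make_fisher_estimator(
      placement_strategy='round_robin',
      variables=[w],              # may also be a callable returning the list
      cov_ema_decay=0.95,         # illustrative value
      damping=1e-2,               # illustrative value
      layer_collection=layers,
      exps=(-1,))                 # support multiplication by F^-1

  cov_update_thunks, inv_update_thunks = est.make_vars_and_create_op_thunks()
```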
-
-    An analogous strategy is followed for inverse update ops, with the list of
-    devices being given by the inv_devices argument.
-
-    Inverse variables, on the other hand, are not placed on any specific device
-    (they will just use the current device placement context, whatever
-    that happens to be). The idea is that the inverse variables belong where
-    they will be accessed most often, which is the device that actually applies
-    the preconditioner to the gradient. The user is responsible for setting
-    the device context for this.
-
-    Args:
-      scope: A string or None. If None it will be set to the name of this
-        estimator (given by the name property). All variables will be created,
-        and all thunks will execute, inside of a variable scope of the given
-        name. (Default: None)
-
-    Returns:
-      cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
-        the list of factors given by the "factors" property.
-      inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
-        the list of factors given by the "factors" property.
-    """
-    pass
-
-  def _apply_transformation(self, vecs_and_vars, transform):
-    """Applies a block-wise transformation to the corresponding vectors.
-
-    Args:
-      vecs_and_vars: List of (vector, variable) pairs.
-      transform: A function of the form f(fb, vec), where vec is the vector
-        to transform and fb is its corresponding block in the matrix, that
-        returns the transformed vector.
-
-    Returns:
-      A list of (transformed vector, var) pairs in the same order as
-      vecs_and_vars.
-    """
-    vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars)
-
-    trans_vecs = utils.SequenceDict()
-
-    for params, fb in self._layers.fisher_blocks.items():
-      trans_vecs[params] = transform(fb, vecs[params])
-
-    return [(trans_vecs[var], var) for _, var in vecs_and_vars]
-
-  def multiply_inverse(self, vecs_and_vars):
-    """Multiplies the vecs by the corresponding (damped) inverses of the blocks.
-
-    Args:
-      vecs_and_vars: List of (vector, variable) pairs.
-
-    Returns:
-      A list of (transformed vector, var) pairs in the same order as
-      vecs_and_vars.
-    """
-    return self.multiply_matpower(-1, vecs_and_vars)
-
-  def multiply(self, vecs_and_vars):
-    """Multiplies the vectors by the corresponding (damped) blocks.
-
-    Args:
-      vecs_and_vars: List of (vector, variable) pairs.
-
-    Returns:
-      A list of (transformed vector, var) pairs in the same order as
-      vecs_and_vars.
-    """
-    return self.multiply_matpower(1, vecs_and_vars)
-
-  def multiply_matpower(self, exp, vecs_and_vars):
-    """Multiplies the vecs by the corresponding matrix powers of the blocks.
-
-    Args:
-      exp: A float representing the power to raise the blocks to before
-        multiplying them by the vectors.
-      vecs_and_vars: List of (vector, variable) pairs.
-
-    Returns:
-      A list of (transformed vector, var) pairs in the same order as
-      vecs_and_vars.
-    """
-    assert exp in self._exps
-
-    fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
-    return self._apply_transformation(vecs_and_vars, fcn)
-
-  def multiply_cholesky(self, vecs_and_vars, transpose=False):
-    """Multiplies the vecs by the corresponding Cholesky factors.
-
-    Args:
-      vecs_and_vars: List of (vector, variable) pairs.
-      transpose: Bool. If true the Cholesky factors are transposed before
-        multiplying the vecs. (Default: False)
-
-    Returns:
-      A list of (transformed vector, var) pairs in the same order as
-      vecs_and_vars.
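-
-    Note: if L denotes the (damped) Cholesky factor of F, so that L L^T = F,
-    then multiplying a standard normal vector by L produces a sample with
-    covariance F. (Compare the sampling note on multiply_cholesky_inverse
-    below.)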
- """ - assert self._compute_cholesky - - fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose) - return self._apply_transformation(vecs_and_vars, fcn) - - def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False): - """Mults the vecs by the inverses of the corresponding Cholesky factors. - - Note: if you are using Cholesky inverse multiplication to sample from - a matrix-variate Gaussian you will want to multiply by the transpose. - Let L be the Cholesky factor of F and observe that - - L^-T * L^-1 = (L * L^T)^-1 = F^-1 . - - Thus we want to multiply by L^-T in order to sample from Gaussian with - covariance F^-1. - - Args: - vecs_and_vars: List of (vector, variable) pairs. - transpose: Bool. If true the Cholesky factor inverses are transposed - before multiplying the vecs. (Default: False) - - Returns: - A list of (transformed vector, var) pairs in the same order as - vecs_and_vars. - """ - assert self._compute_cholesky_inverse - - fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose) - return self._apply_transformation(vecs_and_vars, fcn) - - def _instantiate_factors(self): - """Instantiates FisherFactors' variables. - - Raises: - ValueError: If estimation_mode was improperly specified at construction. - """ - blocks = self.blocks - tensors_to_compute_grads = [ - block.tensors_to_compute_grads() for block in blocks - ] - - try: - grads_lists = self._gradient_fns[self._estimation_mode]( - tensors_to_compute_grads) - except KeyError: - raise ValueError("Unrecognized value {} for estimation_mode.".format( - self._estimation_mode)) - - for grads_list, block in zip(grads_lists, blocks): - block.instantiate_factors(grads_list, self.damping) - - def _check_vars_unmade_and_set_made_flag(self): - if self._made_vars: - raise Exception("Already made variables.") - self._made_vars = True - - def made_vars(self): - return self._made_vars - - def _register_matrix_functions(self): - for block in self.blocks: - for exp in self._exps: - block.register_matpower(exp) - if self._compute_cholesky: - block.register_cholesky() - if self._compute_cholesky_inverse: - block.register_cholesky_inverse() - - def _finalize_layer_collection(self): - self._layers.create_subgraph() - self._layers.check_registration(self.variables) - self._instantiate_factors() - self._register_matrix_functions() - - def create_ops_and_vars_thunks(self, scope=None): - """Create thunks that make the ops and vars on demand. - - This function returns 4 lists of thunks: cov_variable_thunks, - cov_update_thunks, inv_variable_thunks, and inv_update_thunks. - - The length of each list is the number of factors and the i-th element of - each list corresponds to the i-th factor (given by the "factors" property). - - Note that the execution of these thunks must happen in a certain - partial order. The i-th element of cov_variable_thunks must execute - before the i-th element of cov_update_thunks (and also the i-th element - of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks - must execute before the i-th element of inv_update_thunks. - - TL;DR (oversimplified): Execute the thunks according to the order that - they are returned. - - Args: - scope: A string or None. If None it will be set to the name of this - estimator (given by the name property). All thunks will execute inside - of a variable scope of the given name. (Default: None) - Returns: - cov_variable_thunks: A list of thunks that make the cov variables. 
-      cov_update_thunks: A list of thunks that make the cov update ops.
-      inv_variable_thunks: A list of thunks that make the inv variables.
-      inv_update_thunks: A list of thunks that make the inv update ops.
-    """
-    self._check_vars_unmade_and_set_made_flag()
-
-    self._finalize_layer_collection()
-
-    scope = self.name if scope is None else scope
-
-    cov_variable_thunks = [
-        self._create_cov_variable_thunk(factor, scope)
-        for factor in self.factors
-    ]
-    cov_update_thunks = [
-        self._create_cov_update_thunk(factor, scope) for factor in self.factors
-    ]
-    inv_variable_thunks = [
-        self._create_inv_variable_thunk(factor, scope)
-        for factor in self.factors
-    ]
-    inv_update_thunks = [
-        self._create_inv_update_thunk(factor, scope) for factor in self.factors
-    ]
-
-    return (cov_variable_thunks, cov_update_thunks,
-            inv_variable_thunks, inv_update_thunks)
-
-  def _create_cov_variable_thunk(self, factor, scope):
-    """Constructs a covariance variable thunk for a single FisherFactor."""
-
-    def thunk():
-      with variable_scope.variable_scope(scope):
-        return factor.instantiate_cov_variables()
-
-    return thunk
-
-  def _create_cov_update_thunk(self, factor, scope):
-    """Constructs a covariance update thunk for a single FisherFactor."""
-
-    def thunk():
-      with variable_scope.variable_scope(scope):
-        return factor.make_covariance_update_op(self._cov_ema_decay)
-
-    return thunk
-
-  def _create_inv_variable_thunk(self, factor, scope):
-    """Constructs an inverse variable thunk for a single FisherFactor."""
-
-    def thunk():
-      with variable_scope.variable_scope(scope):
-        return factor.instantiate_inv_variables()
-
-    return thunk
-
-  def _create_inv_update_thunk(self, factor, scope):
-    """Constructs an inverse update thunk for a single FisherFactor."""
-
-    def thunk():
-      with variable_scope.variable_scope(scope):
-        return control_flow_ops.group(factor.make_inverse_update_ops())
-
-    return thunk
-
-  def _get_grads_lists_gradients(self, tensors):
-    # Passing in a list of loss values is better than passing in the sum as
-    # the latter creates unnecessary ops on the default device.
-    grads_flat = gradients_impl.gradients(
-        self._layers.eval_losses_on_samples(),
-        nest.flatten(tensors),
-        colocate_gradients_with_ops=self._colocate_gradients_with_ops)
-    grads_all = nest.pack_sequence_as(tensors, grads_flat)
-    return tuple((grad,) for grad in grads_all)
-
-  def _get_grads_lists_empirical(self, tensors):
-    # Passing in a list of loss values is better than passing in the sum as
-    # the latter creates unnecessary ops on the default device.
-    grads_flat = gradients_impl.gradients(
-        self._layers.eval_losses(),
-        nest.flatten(tensors),
-        colocate_gradients_with_ops=self._colocate_gradients_with_ops)
-    grads_all = nest.pack_sequence_as(tensors, grads_flat)
-    return tuple((grad,) for grad in grads_all)
-
-  def _get_transformed_random_signs(self):
-    transformed_random_signs = []
-    for loss in self._layers.losses:
-      with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
-        transformed_random_signs.append(
-            loss.multiply_fisher_factor(
-                utils.generate_random_signs(loss.fisher_factor_inner_shape)))
-    return transformed_random_signs
-
-  def _get_grads_lists_curvature_prop(self, tensors):
-    loss_inputs = list(loss.inputs for loss in self._layers.losses)
-    transformed_random_signs = self._get_transformed_random_signs()
-    grads_flat = gradients_impl.gradients(
-        nest.flatten(loss_inputs),
-        nest.flatten(tensors),
-        grad_ys=nest.flatten(transformed_random_signs),
-        
colocate_gradients_with_ops=self._colocate_gradients_with_ops) - grads_all = nest.pack_sequence_as(tensors, grads_flat) - return tuple((grad,) for grad in grads_all) - - def _get_grads_lists_exact(self, tensors): - """No docstring required.""" - # Loop over all coordinates of all losses. - grads_all = [] - for loss in self._layers.losses: - with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]): - for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]): - transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( - index) - grads_flat = gradients_impl.gradients( - loss.inputs, - nest.flatten(tensors), - grad_ys=transformed_one_hot, - colocate_gradients_with_ops=self._colocate_gradients_with_ops) - grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) - return zip(*grads_all) - - -class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin, - FisherEstimator): - """Fisher estimator which provides round robin device placement strategy.""" - pass diff --git a/tensorflow/contrib/kfac/python/ops/estimator_lib.py b/tensorflow/contrib/kfac/python/ops/estimator_lib.py deleted file mode 100644 index 9c9fef471f..0000000000 --- a/tensorflow/contrib/kfac/python/ops/estimator_lib.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Defines the high-level Fisher estimator class.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.estimator import * -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import - -_allowed_symbols = [ - 'FisherEstimator', - 'make_fisher_estimator', -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py deleted file mode 100644 index 9fa6eb7dcd..0000000000 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ /dev/null @@ -1,1752 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""FisherBlock definitions. 
- -This library contains classes for estimating blocks in a model's Fisher -Information matrix. Suppose one has a model that parameterizes a posterior -distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its -Fisher Information matrix is given by, - - $$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$ - -where, - - $$v(x, y, params) = (d / d params) log p(y | x, params)$$ - -and the expectation is taken with respect to the data's distribution for 'x' and -the model's posterior distribution for 'y', - - x ~ p(x) - y ~ p(y | x, params) - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import abc -import enum # pylint: disable=g-bad-import-order - -import numpy as np -import six - -from tensorflow.contrib.kfac.python.ops import fisher_factors -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.util import nest - -# For blocks corresponding to convolutional layers, or any type of block where -# the parameters can be thought of as being replicated in time or space, -# we want to adjust the scale of the damping by -# damping /= num_replications ** NORMALIZE_DAMPING_POWER -NORMALIZE_DAMPING_POWER = 1.0 - -# Methods for adjusting damping for FisherBlocks. See -# compute_pi_adjusted_damping() for details. -PI_OFF_NAME = "off" -PI_TRACENORM_NAME = "tracenorm" -PI_TYPE = PI_TRACENORM_NAME - - -def set_global_constants(normalize_damping_power=None, pi_type=None): - """Sets various global constants used by the classes in this module.""" - global NORMALIZE_DAMPING_POWER - global PI_TYPE - - if normalize_damping_power is not None: - NORMALIZE_DAMPING_POWER = normalize_damping_power - - if pi_type is not None: - PI_TYPE = pi_type - - -def normalize_damping(damping, num_replications): - """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER.""" - if NORMALIZE_DAMPING_POWER: - return damping / (num_replications ** NORMALIZE_DAMPING_POWER) - return damping - - -def compute_pi_tracenorm(left_cov, right_cov): - r"""Computes the scalar constant pi for Tikhonov regularization/damping. - - $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$ - See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. - - Args: - left_cov: A LinearOperator object. The left Kronecker factor "covariance". - right_cov: A LinearOperator object. The right Kronecker factor "covariance". - - Returns: - The computed scalar constant pi for these Kronecker Factors (as a Tensor). - """ - # Instead of dividing by the dim of the norm, we multiply by the dim of the - # other norm. This works out the same in the ratio. - left_norm = left_cov.trace() * int(right_cov.domain_dimension) - right_norm = right_cov.trace() * int(left_cov.domain_dimension) - return math_ops.sqrt(left_norm / right_norm) - - -def compute_pi_adjusted_damping(left_cov, right_cov, damping): - - if PI_TYPE == PI_TRACENORM_NAME: - pi = compute_pi_tracenorm(left_cov, right_cov) - return (damping * pi, damping / pi) - - elif PI_TYPE == PI_OFF_NAME: - return (damping, damping) - - -class PackagedFunc(object): - """A Python thunk with a stable ID. - - Enables stable names for lambdas. - """ - - def __init__(self, func, func_id): - """Initializes PackagedFunc. - - Args: - func: a zero-arg Python function. - func_id: a hashable, function that produces a hashable, or a list/tuple - thereof. 
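-
-    For example (a hypothetical usage sketch, with `damping` some Python
-    float captured from the enclosing scope):
-
-      func = PackagedFunc(lambda: damping, (damping,))
-      value = func()      # calls the wrapped thunk
-      key = func.func_id  # stable, hashable identifier, e.g. for caching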
- """ - self._func = func - func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,) - self._func_id = func_id - - def __call__(self): - return self._func() - - @property - def func_id(self): - """A hashable identifier for this function.""" - return tuple(elt() if callable(elt) else elt for elt in self._func_id) - - -def _package_func(func, func_id): - return PackagedFunc(func, func_id) - - -@six.add_metaclass(abc.ABCMeta) -class FisherBlock(object): - """Abstract base class for objects modeling approximate Fisher matrix blocks. - - Subclasses must implement register_matpower, multiply_matpower, - instantiate_factors, tensors_to_compute_grads, and num_registered_towers - methods. - """ - - def __init__(self, layer_collection): - self._layer_collection = layer_collection - - @abc.abstractmethod - def instantiate_factors(self, grads_list, damping): - """Creates and registers the component factors of this Fisher block. - - Args: - grads_list: A list gradients (each a Tensor or tuple of Tensors) with - respect to the tensors returned by tensors_to_compute_grads() that - are to be used to estimate the block. - damping: The damping factor (float or Tensor). - """ - pass - - @abc.abstractmethod - def register_matpower(self, exp): - """Registers a matrix power to be computed by the block. - - Args: - exp: A float representing the power to raise the block by. - """ - pass - - @abc.abstractmethod - def register_cholesky(self): - """Registers a Cholesky factor to be computed by the block.""" - pass - - @abc.abstractmethod - def register_cholesky_inverse(self): - """Registers an inverse Cholesky factor to be computed by the block.""" - pass - - def register_inverse(self): - """Registers a matrix inverse to be computed by the block.""" - self.register_matpower(-1) - - @abc.abstractmethod - def multiply_matpower(self, vector, exp): - """Multiplies the vector by the (damped) matrix-power of the block. - - Args: - vector: The vector (a Tensor or tuple of Tensors) to be multiplied. - exp: A float representing the power to raise the block by before - multiplying it by the vector. - - Returns: - The vector left-multiplied by the (damped) matrix-power of the block. - """ - pass - - def multiply_inverse(self, vector): - """Multiplies the vector by the (damped) inverse of the block. - - Args: - vector: The vector (a Tensor or tuple of Tensors) to be multiplied. - - Returns: - The vector left-multiplied by the (damped) inverse of the block. - """ - return self.multiply_matpower(vector, -1) - - def multiply(self, vector): - """Multiplies the vector by the (damped) block. - - Args: - vector: The vector (a Tensor or tuple of Tensors) to be multiplied. - - Returns: - The vector left-multiplied by the (damped) block. - """ - return self.multiply_matpower(vector, 1) - - @abc.abstractmethod - def multiply_cholesky(self, vector, transpose=False): - """Multiplies the vector by the (damped) Cholesky-factor of the block. - - Args: - vector: The vector (a Tensor or tuple of Tensors) to be multiplied. - transpose: Bool. If true the Cholesky factor is transposed before - multiplying the vector. (Default: False) - - Returns: - The vector left-multiplied by the (damped) Cholesky-factor of the block. - """ - pass - - @abc.abstractmethod - def multiply_cholesky_inverse(self, vector, transpose=False): - """Multiplies vector by the (damped) inverse Cholesky-factor of the block. - - Args: - vector: The vector (a Tensor or tuple of Tensors) to be multiplied. - transpose: Bool. 
If true the Cholesky factor inverse is transposed - before multiplying the vector. (Default: False) - Returns: - Vector left-multiplied by (damped) inverse Cholesky-factor of the block. - """ - pass - - @abc.abstractmethod - def tensors_to_compute_grads(self): - """Returns the Tensor(s) with respect to which this FisherBlock needs grads. - """ - pass - - @abc.abstractproperty - def num_registered_towers(self): - """Number of towers registered for this FisherBlock. - - Typically equal to the number of towers in a multi-tower setup. - """ - pass - - -class FullFB(FisherBlock): - """FisherBlock using a full matrix estimate (no approximations). - - FullFB uses a full matrix estimate (no approximations), and should only ever - be used for very low dimensional parameters. - - Note that this uses the naive "square the sum estimator", and so is applicable - to any type of parameter in principle, but has very high variance. - """ - - def __init__(self, layer_collection, params): - """Creates a FullFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: The parameters of this layer (Tensor or tuple of Tensors). - """ - self._batch_sizes = [] - self._params = params - - super(FullFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - self._damping_func = _package_func(lambda: damping, (damping,)) - - self._factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullFactor, (grads_list, self._batch_size)) - - def register_matpower(self, exp): - self._factor.register_matpower(exp, self._damping_func) - - def register_cholesky(self): - self._factor.register_cholesky(self._damping_func) - - def register_cholesky_inverse(self): - self._factor.register_cholesky_inverse(self._damping_func) - - def _multiply_matrix(self, matrix, vector, transpose=False): - vector_flat = utils.tensors_to_column(vector) - out_flat = matrix.matmul(vector_flat, adjoint=transpose) - return utils.column_to_tensors(vector, out_flat) - - def multiply_matpower(self, vector, exp): - matrix = self._factor.get_matpower(exp, self._damping_func) - return self._multiply_matrix(matrix, vector) - - def multiply_cholesky(self, vector, transpose=False): - matrix = self._factor.get_cholesky(self._damping_func) - return self._multiply_matrix(matrix, vector, transpose=transpose) - - def multiply_cholesky_inverse(self, vector, transpose=False): - matrix = self._factor.get_cholesky_inverse(self._damping_func) - return self._multiply_matrix(matrix, vector, transpose=transpose) - - def full_fisher_block(self): - """Explicitly constructs the full Fisher block.""" - return self._factor.get_cov_as_linear_operator().to_dense() - - def tensors_to_compute_grads(self): - return self._params - - def register_additional_tower(self, batch_size): - """Register an additional tower. - - Args: - batch_size: The batch size, used in the covariance estimator. - """ - self._batch_sizes.append(batch_size) - - @property - def num_registered_towers(self): - return len(self._batch_sizes) - - @property - def _batch_size(self): - return math_ops.reduce_sum(self._batch_sizes) - - -@six.add_metaclass(abc.ABCMeta) -class DiagonalFB(FisherBlock): - """A base class for FisherBlocks that use diagonal approximations.""" - - def register_matpower(self, exp): - # Not needed for this. Matrix powers are computed on demand in the - # diagonal case - pass - - def register_cholesky(self): - # Not needed for this. 
Cholesky's are computed on demand in the - # diagonal case - pass - - def register_cholesky_inverse(self): - # Not needed for this. Cholesky inverses's are computed on demand in the - # diagonal case - pass - - def _multiply_matrix(self, matrix, vector): - vector_flat = utils.tensors_to_column(vector) - out_flat = matrix.matmul(vector_flat) - return utils.column_to_tensors(vector, out_flat) - - def multiply_matpower(self, vector, exp): - matrix = self._factor.get_matpower(exp, self._damping_func) - return self._multiply_matrix(matrix, vector) - - def multiply_cholesky(self, vector, transpose=False): - matrix = self._factor.get_cholesky(self._damping_func) - return self._multiply_matrix(matrix, vector) - - def multiply_cholesky_inverse(self, vector, transpose=False): - matrix = self._factor.get_cholesky_inverse(self._damping_func) - return self._multiply_matrix(matrix, vector) - - def full_fisher_block(self): - return self._factor.get_cov_as_linear_operator().to_dense() - - -class NaiveDiagonalFB(DiagonalFB): - """FisherBlock using a diagonal matrix approximation. - - This type of approximation is generically applicable but quite primitive. - - Note that this uses the naive "square the sum estimator", and so is applicable - to any type of parameter in principle, but has very high variance. - """ - - def __init__(self, layer_collection, params): - """Creates a NaiveDiagonalFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: The parameters of this layer (Tensor or tuple of Tensors). - """ - self._params = params - self._batch_sizes = [] - - super(NaiveDiagonalFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - self._damping_func = _package_func(lambda: damping, (damping,)) - - self._factor = self._layer_collection.make_or_get_factor( - fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size)) - - def tensors_to_compute_grads(self): - return self._params - - def register_additional_tower(self, batch_size): - """Register an additional tower. - - Args: - batch_size: The batch size, used in the covariance estimator. - """ - self._batch_sizes.append(batch_size) - - @property - def num_registered_towers(self): - return len(self._batch_sizes) - - @property - def _batch_size(self): - return math_ops.reduce_sum(self._batch_sizes) - - -class InputOutputMultiTower(object): - """Mix-in class for blocks with inputs & outputs and multiple mini-batches.""" - - def __init__(self, *args, **kwargs): - self.__inputs = [] - self.__outputs = [] - super(InputOutputMultiTower, self).__init__(*args, **kwargs) - - def _process_data(self, grads_list): - """Process data into the format used by the factors. - - This function takes inputs and grads_lists data and processes it into - one of the formats expected by the FisherFactor classes (depending on - the value of the global configuration variable TOWER_STRATEGY). - - The initial format of self._inputs is expected to be a list of Tensors - over towers. Similarly grads_lists is expected to be a list over sources - of such lists. - - If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single - tensor (represented as a PartitionedTensor object) equal to the - concatenation (across towers) of all of the elements of self._inputs. And - similarly grads_list is formatted into a tuple (over sources) of such - tensors (also represented as PartitionedTensors). 
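-
-    For example (a hypothetical sketch with one source and two towers),
-    under "concat":
-
-      inputs = [a_tower0, a_tower1]        # list over towers
-      grads_list = [[g_tower0, g_tower1]]  # list over sources of tower lists
-      # becomes
-      inputs = (utils.PartitionedTensor([a_tower0, a_tower1]),)
-      grads_list = ((utils.PartitionedTensor([g_tower0, g_tower1]),),)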
- - If TOWER_STRATEGY is "separate", formatting of inputs and grads_list - remains unchanged from the initial format (although possibly converting - from lists into tuples). - - Args: - grads_list: grads_list in its initial format (see above). - - Returns: - inputs: self._inputs transformed into the appropriate format (see - above). - grads_list: grads_list transformed into the appropriate format (see - above). - - Raises: - ValueError: if TOWER_STRATEGY is not one of "separate" or "concat". - """ - inputs = self._inputs - # inputs is a list over towers of Tensors - # grads_list is a list of list with the first index being sources and the - # second being towers. - if fisher_factors.TOWER_STRATEGY == "concat": - # Merge towers together into a PartitionedTensor. We package it in - # a singleton tuple since the factors will expect a list over towers - inputs = (utils.PartitionedTensor(inputs),) - # Do the same for grads_list but preserve leading sources dimension - grads_list = tuple((utils.PartitionedTensor(grads),) - for grads in grads_list) - elif fisher_factors.TOWER_STRATEGY == "separate": - inputs = tuple(inputs) - grads_list = tuple(grads_list) - - else: - raise ValueError("Global config variable TOWER_STRATEGY must be one of " - "'concat' or 'separate'.") - - return inputs, grads_list - - def tensors_to_compute_grads(self): - """Tensors to compute derivative of loss with respect to.""" - return tuple(self._outputs) - - def register_additional_tower(self, inputs, outputs): - self._inputs.append(inputs) - self._outputs.append(outputs) - - @property - def num_registered_towers(self): - result = len(self._inputs) - assert result == len(self._outputs) - return result - - @property - def _inputs(self): - return self.__inputs - - @property - def _outputs(self): - return self.__outputs - - -class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB): - """FisherBlock for fully-connected (dense) layers using a diagonal approx. - - Estimates the Fisher Information matrix's diagonal entries for a fully - connected layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of - squares" estimator. - - Let 'params' be a vector parameterizing a model and 'i' an arbitrary index - into it. We are interested in Fisher(params)[i, i]. This is, - - $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] - = E[ v(x, y, params)[i] ^ 2 ]$$ - - Consider fully connected layer in this model with (unshared) weight matrix - 'w'. For an example 'x' that produces layer inputs 'a' and output - preactivations 's', - - $$v(x, y, w) = vec( a (d loss / d s)^T )$$ - - This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding - to the layer's parameters 'w'. - """ - - def __init__(self, layer_collection, has_bias=False): - """Creates a FullyConnectedDiagonalFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - has_bias: Whether the component Kronecker factors have an additive bias. 
-        (Default: False)
-    """
-    self._has_bias = has_bias
-
-    super(FullyConnectedDiagonalFB, self).__init__(layer_collection)
-
-  def instantiate_factors(self, grads_list, damping):
-    inputs, grads_list = self._process_data(grads_list)
-
-    self._factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.FullyConnectedDiagonalFactor,
-        (inputs, grads_list, self._has_bias))
-
-    self._damping_func = _package_func(lambda: damping, (damping,))
-
-
-class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB):
-  """FisherBlock for 2-D convolutional layers using a diagonal approx.
-
-  Estimates the Fisher Information matrix's diagonal entries for a
-  convolutional layer. Unlike NaiveDiagonalFB this uses the low-variance
-  "sum of squares" estimator.
-
-  Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
-  into it. We are interested in Fisher(params)[i, i]. This is,
-
-  $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
-                         = E[ v(x, y, params)[i] ^ 2 ]$$
-
-  Consider a convolutional layer in this model with (unshared) filter matrix
-  'w'. For an example image 'x' that produces layer inputs 'a' and output
-  preactivations 's',
-
-  $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$
-
-  where 'loc' is a single (x, y) location in an image.
-
-  This FisherBlock tracks Fisher(params)[i, i] for all indices 'i'
-  corresponding to the layer's parameters 'w'.
-  """
-
-  def __init__(self,
-               layer_collection,
-               params,
-               strides,
-               padding,
-               data_format=None,
-               dilations=None):
-    """Creates a ConvDiagonalFB block.
-
-    Args:
-      layer_collection: The collection of all layers in the K-FAC approximate
-        Fisher information matrix to which this FisherBlock belongs.
-      params: The parameters (Tensor or tuple of Tensors) of this layer. If
-        kernel alone, a Tensor of shape [kernel_height, kernel_width,
-        in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
-        containing the former and a Tensor of shape [out_channels].
-      strides: The stride size in this layer (1-D Tensor of length 4).
-      padding: The padding in this layer (e.g. "SAME").
-      data_format: str or None. Format of input data.
-      dilations: List of 4 ints or None. Rate for dilation along all
-        dimensions.
-
-    Raises:
-      ValueError: if strides is not length-4.
-      ValueError: if dilations is not length-4.
-      ValueError: if channels are not the last dimension.
-    """
-    if len(strides) != 4:
-      raise ValueError("strides must contain 4 numbers.")
-
-    if dilations is None:
-      dilations = [1, 1, 1, 1]
-
-    if len(dilations) != 4:
-      raise ValueError("dilations must contain 4 numbers.")
-
-    if not utils.is_data_format_channel_last(data_format):
-      raise ValueError("data_format must be channels-last.")
-
-    self._strides = maybe_tuple(strides)
-    self._padding = padding
-    self._data_format = data_format
-    self._dilations = maybe_tuple(dilations)
-    self._has_bias = isinstance(params, (tuple, list))
-
-    fltr = params[0] if self._has_bias else params
-    self._filter_shape = tuple(fltr.shape.as_list())
-
-    if len(self._filter_shape) != 4:
-      raise ValueError(
-          "Convolution filter must be of shape"
-          " [filter_height, filter_width, in_channels, out_channels].")
-
-    super(ConvDiagonalFB, self).__init__(layer_collection)
-
-  def instantiate_factors(self, grads_list, damping):
-    inputs, grads_list = self._process_data(grads_list)
-
-    # Infer number of locations upon which convolution is applied.
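-    # (E.g., hypothetically, inputs of shape [batch, 32, 32, channels] with
-    # strides [1, 2, 2, 1] give num_locations = (32 * 32) // (2 * 2) = 256.)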
- self._num_locations = num_conv_locations(inputs[0].shape.as_list(), - self._strides) - - self._factor = self._layer_collection.make_or_get_factor( - fisher_factors.ConvDiagonalFactor, - (inputs, grads_list, self._filter_shape, self._strides, self._padding, - self._data_format, self._dilations, self._has_bias)) - - def damping_func(): - return self._num_locations * normalize_damping(damping, - self._num_locations) - - damping_id = (self._num_locations, "mult", "normalize_damping", damping, - self._num_locations) - self._damping_func = _package_func(damping_func, damping_id) - - -class KroneckerProductFB(FisherBlock): - """A base class for blocks with separate input and output Kronecker factors. - - The Fisher block is approximated as a Kronecker product of the input and - output factors. - """ - - def _setup_damping(self, damping, normalization=None): - """Makes functions that compute the damping values for both factors.""" - def compute_damping(): - if normalization is not None: - maybe_normalized_damping = normalize_damping(damping, normalization) - else: - maybe_normalized_damping = damping - - return compute_pi_adjusted_damping( - self._input_factor.get_cov_as_linear_operator(), - self._output_factor.get_cov_as_linear_operator(), - maybe_normalized_damping**0.5) - - if normalization is not None: - damping_id = ("compute_pi_adjusted_damping", - "cov", self._input_factor.name, - "cov", self._output_factor.name, - "normalize_damping", damping, normalization, "power", 0.5) - else: - damping_id = ("compute_pi_adjusted_damping", - "cov", self._input_factor.name, - "cov", self._output_factor.name, - damping, "power", 0.5) - - self._input_damping_func = _package_func(lambda: compute_damping()[0], - damping_id + ("ref", 0)) - self._output_damping_func = _package_func(lambda: compute_damping()[1], - damping_id + ("ref", 1)) - - def register_matpower(self, exp): - self._input_factor.register_matpower(exp, self._input_damping_func) - self._output_factor.register_matpower(exp, self._output_damping_func) - - def register_cholesky(self): - self._input_factor.register_cholesky(self._input_damping_func) - self._output_factor.register_cholesky(self._output_damping_func) - - def register_cholesky_inverse(self): - self._input_factor.register_cholesky_inverse(self._input_damping_func) - self._output_factor.register_cholesky_inverse(self._output_damping_func) - - @property - def _renorm_coeff(self): - """Kronecker factor multiplier coefficient. - - If this FisherBlock is represented as 'FB = c * kron(left, right)', then - this is 'c'. - - Returns: - 0-D Tensor. 
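-
-      (For this base class the coefficient is simply 1.0. Subclasses override
-      it; e.g. ConvKFCBasicFB uses the number of spatial locations the filter
-      is applied to.)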
- """ - return 1.0 - - def _multiply_factored_matrix(self, left_factor, right_factor, vector, - extra_scale=1.0, transpose_left=False, - transpose_right=False): - reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = right_factor.matmul_right(reshaped_vector, - adjoint=transpose_right) - reshaped_out = left_factor.matmul(reshaped_out, - adjoint=transpose_left) - if extra_scale != 1.0: - reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype) - return utils.mat2d_to_layer_params(vector, reshaped_out) - - def multiply_matpower(self, vector, exp): - left_factor = self._input_factor.get_matpower( - exp, self._input_damping_func) - right_factor = self._output_factor.get_matpower( - exp, self._output_damping_func) - extra_scale = float(self._renorm_coeff)**exp - return self._multiply_factored_matrix(left_factor, right_factor, vector, - extra_scale=extra_scale) - - def multiply_cholesky(self, vector, transpose=False): - left_factor = self._input_factor.get_cholesky(self._input_damping_func) - right_factor = self._output_factor.get_cholesky(self._output_damping_func) - extra_scale = float(self._renorm_coeff)**0.5 - return self._multiply_factored_matrix(left_factor, right_factor, vector, - extra_scale=extra_scale, - transpose_left=transpose, - transpose_right=not transpose) - - def multiply_cholesky_inverse(self, vector, transpose=False): - left_factor = self._input_factor.get_cholesky_inverse( - self._input_damping_func) - right_factor = self._output_factor.get_cholesky_inverse( - self._output_damping_func) - extra_scale = float(self._renorm_coeff)**-0.5 - return self._multiply_factored_matrix(left_factor, right_factor, vector, - extra_scale=extra_scale, - transpose_left=transpose, - transpose_right=not transpose) - - def full_fisher_block(self): - """Explicitly constructs the full Fisher block. - - Used for testing purposes. (In general, the result may be very large.) - - Returns: - The full Fisher block. - """ - left_factor = self._input_factor.get_cov_as_linear_operator().to_dense() - right_factor = self._output_factor.get_cov_as_linear_operator().to_dense() - return self._renorm_coeff * utils.kronecker_product(left_factor, - right_factor) - - -class EmbeddingKFACFB(InputOutputMultiTower, KroneckerProductFB): - """K-FAC FisherBlock for embedding layers. - - This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its - input factor is approximated by a diagonal matrix. In the case that each - example references exactly one embedding, this approximation is exact. - - Does not support bias parameters. - """ - - def __init__(self, layer_collection, vocab_size): - """Creates a EmbeddingKFACFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - vocab_size: int. Size of vocabulary for this embedding layer. - """ - self._vocab_size = vocab_size - - super(EmbeddingKFACFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - """Instantiate Kronecker Factors for this FisherBlock. - - Args: - grads_list: List of list of Tensors. grads_list[i][j] is the - gradient of the loss with respect to 'outputs' from source 'i' and - tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size]. - damping: 0-D Tensor or float. 'damping' * identity is approximately added - to this FisherBlock's Fisher approximation. 
-    """
-    inputs, grads_list = self._process_data(grads_list)
-
-    self._input_factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.EmbeddingInputKroneckerFactor,
-        (inputs, self._vocab_size))
-    self._output_factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.FullyConnectedKroneckerFactor, (grads_list,))
-    self._setup_damping(damping)
-
-
-class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):
-  """K-FAC FisherBlock for fully-connected (dense) layers.
-
-  This uses the Kronecker-factorized approximation from the original
-  K-FAC paper (https://arxiv.org/abs/1503.05671).
-  """
-
-  def __init__(self, layer_collection, has_bias=False):
-    """Creates a FullyConnectedKFACBasicFB block.
-
-    Args:
-      layer_collection: The collection of all layers in the K-FAC approximate
-        Fisher information matrix to which this FisherBlock belongs.
-      has_bias: Whether the component Kronecker factors have an additive bias.
-        (Default: False)
-    """
-    self._has_bias = has_bias
-
-    super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)
-
-  def instantiate_factors(self, grads_list, damping):
-    """Instantiate Kronecker Factors for this FisherBlock.
-
-    Args:
-      grads_list: List of list of Tensors. grads_list[i][j] is the
-        gradient of the loss with respect to 'outputs' from source 'i' and
-        tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
-      damping: 0-D Tensor or float. 'damping' * identity is approximately added
-        to this FisherBlock's Fisher approximation.
-    """
-    inputs, grads_list = self._process_data(grads_list)
-
-    self._input_factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.FullyConnectedKroneckerFactor,
-        ((inputs,), self._has_bias))
-    self._output_factor = self._layer_collection.make_or_get_factor(
-        fisher_factors.FullyConnectedKroneckerFactor,
-        (grads_list,))
-    self._setup_damping(damping)
-
-
-class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
-  r"""FisherBlock for convolutional layers using the basic KFC approx.
-
-  Estimates the Fisher Information matrix's block for a convolutional
-  layer.
-
-  Consider a convolutional layer in this model with (unshared) filter matrix
-  'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
-  this FisherBlock estimates,
-
-  $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T],
-                                   E[flat(ds) flat(ds)^T])$$
-
-  where
-
-  $$ds = (d / ds) log p(y | x, w)$$
-
-  and #locations is the number of (x, y) locations where 'w' is applied. The
-  expectation is taken over all examples and locations, and flat()
-  concatenates an array's leading dimensions.
-
-  See equation 23 in https://arxiv.org/abs/1602.01407 for details.
-  """
-
-  def __init__(self,
-               layer_collection,
-               params,
-               padding,
-               strides=None,
-               dilation_rate=None,
-               data_format=None,
-               extract_patches_fn=None):
-    """Creates a ConvKFCBasicFB block.
-
-    Args:
-      layer_collection: The collection of all layers in the K-FAC approximate
-        Fisher information matrix to which this FisherBlock belongs.
-      params: The parameters (Tensor or tuple of Tensors) of this layer. If
-        kernel alone, a Tensor of shape [..spatial_filter_shape..,
-        in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
-        containing the former and a Tensor of shape [out_channels].
-      padding: str. Padding method.
-      strides: List of ints or None. Contains [..spatial_filter_strides..] if
-        'extract_patches_fn' is compatible with tf.nn.convolution(), else
-        [1, ..spatial_filter_strides.., 1].
- dilation_rate: List of ints or None. Rate for dilation along each spatial - dimension if 'extract_patches_fn' is compatible with - tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. - data_format: str or None. Format of input data. - extract_patches_fn: str or None. Name of function that extracts image - patches. One of "extract_convolution_patches", "extract_image_patches", - "extract_pointwise_conv2d_patches". - """ - self._padding = padding - self._strides = maybe_tuple(strides) - self._dilation_rate = maybe_tuple(dilation_rate) - self._data_format = data_format - self._extract_patches_fn = extract_patches_fn - self._has_bias = isinstance(params, (tuple, list)) - - fltr = params[0] if self._has_bias else params - self._filter_shape = tuple(fltr.shape.as_list()) - - super(ConvKFCBasicFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(inputs[0].shape.as_list(), - self._strides) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.ConvInputKroneckerFactor, - (inputs, self._filter_shape, self._padding, self._strides, - self._dilation_rate, self._data_format, self._extract_patches_fn, - self._has_bias)) - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) - - self._setup_damping(damping, normalization=self._num_locations) - - @property - def _renorm_coeff(self): - return self._num_locations - - -class DepthwiseConvDiagonalFB(ConvDiagonalFB): - """FisherBlock for depthwise_conv2d(). - - Equivalent to ConvDiagonalFB applied to each input channel in isolation. - """ - - def __init__(self, - layer_collection, - params, - strides, - padding, - rate=None, - data_format=None): - """Creates a DepthwiseConvKFCBasicFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: Tensor of shape [filter_height, filter_width, in_channels, - channel_multiplier]. - strides: List of 4 ints. Strides along all dimensions. - padding: str. Padding method. - rate: List of 4 ints or None. Rate for dilation along all dimensions. - data_format: str or None. Format of input data. - - Raises: - NotImplementedError: If parameters contains bias. - ValueError: If filter is not 4-D. - ValueError: If strides is not length-4. - ValueError: If rates is not length-2. - ValueError: If channels are not last dimension. - """ - if isinstance(params, (tuple, list)): - raise NotImplementedError("Bias not yet supported.") - - if params.shape.ndims != 4: - raise ValueError("Filter must be 4-D.") - - if len(strides) != 4: - raise ValueError("strides must account for 4 dimensions.") - - if rate is not None: - if len(rate) != 2: - raise ValueError("rate must only account for spatial dimensions.") - rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. - - if not utils.is_data_format_channel_last(data_format): - raise ValueError("data_format must be channels-last.") - - super(DepthwiseConvDiagonalFB, self).__init__( - layer_collection=layer_collection, - params=params, - strides=strides, - padding=padding, - dilations=rate, - data_format=data_format) - - # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). 
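-    # (ConvDiagonalFB.__init__() records the raw depthwise filter shape
-    # [h, w, in_channels, channel_multiplier]; the factor instead needs the
-    # conv2d-equivalent shape [h, w, in_channels,
-    # in_channels * channel_multiplier], which is set below.)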
- filter_height, filter_width, in_channels, channel_multiplier = ( - params.shape.as_list()) - self._filter_shape = (filter_height, filter_width, in_channels, - in_channels * channel_multiplier) - - def _multiply_matrix(self, matrix, vector): - conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) - conv2d_result = super( - DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector) - return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) - - -class DepthwiseConvKFCBasicFB(ConvKFCBasicFB): - """FisherBlock for depthwise_conv2d(). - - Equivalent to ConvKFCBasicFB applied to each input channel in isolation. - """ - - def __init__(self, - layer_collection, - params, - strides, - padding, - rate=None, - data_format=None): - """Creates a DepthwiseConvKFCBasicFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: Tensor of shape [filter_height, filter_width, in_channels, - channel_multiplier]. - strides: List of 4 ints. Strides along all dimensions. - padding: str. Padding method. - rate: List of 4 ints or None. Rate for dilation along all dimensions. - data_format: str or None. Format of input data. - - Raises: - NotImplementedError: If parameters contains bias. - ValueError: If filter is not 4-D. - ValueError: If strides is not length-4. - ValueError: If rates is not length-2. - ValueError: If channels are not last dimension. - """ - if isinstance(params, (tuple, list)): - raise NotImplementedError("Bias not yet supported.") - - if params.shape.ndims != 4: - raise ValueError("Filter must be 4-D.") - - if len(strides) != 4: - raise ValueError("strides must account for 4 dimensions.") - - if rate is not None: - if len(rate) != 2: - raise ValueError("rate must only account for spatial dimensions.") - rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. - - if not utils.is_data_format_channel_last(data_format): - raise ValueError("data_format must be channels-last.") - - super(DepthwiseConvKFCBasicFB, self).__init__( - layer_collection=layer_collection, - params=params, - padding=padding, - strides=strides, - dilation_rate=rate, - data_format=data_format, - extract_patches_fn="extract_image_patches") - - # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). - filter_height, filter_width, in_channels, channel_multiplier = ( - params.shape.as_list()) - self._filter_shape = (filter_height, filter_width, in_channels, - in_channels * channel_multiplier) - - def _multiply_factored_matrix(self, left_factor, right_factor, vector, - extra_scale=1.0, transpose_left=False, - transpose_right=False): - conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) - conv2d_result = super( - DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix( - left_factor, right_factor, conv2d_vector, extra_scale=extra_scale, - transpose_left=transpose_left, transpose_right=transpose_right) - return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) - - -def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin - """Converts a convolution filter for use with conv2d. - - Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's - compatible with tf.nn.conv2d(). - - Args: - filter: Tensor of shape [height, width, in_channels, channel_multiplier]. - name: None or str. Name of Op. - - Returns: - Tensor of shape [height, width, in_channels, out_channels]. 
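-
-  For example (hypothetical shapes): a depthwise filter of shape
-  [3, 3, 8, 2] maps to a conv2d filter of shape [3, 3, 8, 16] that is zero
-  everywhere except for the block-diagonal entries.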
- - """ - with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter", - [filter]): - filter = ops.convert_to_tensor(filter) - filter_height, filter_width, in_channels, channel_multiplier = ( - filter.shape.as_list()) - - results = [] - for i in range(in_channels): - # Slice out one in_channel's filter. Insert zeros around it to force it - # to affect that channel and that channel alone. - elements = [] - if i > 0: - elements.append( - array_ops.zeros( - [filter_height, filter_width, i, channel_multiplier])) - elements.append(filter[:, :, i:(i + 1), :]) - if i + 1 < in_channels: - elements.append( - array_ops.zeros([ - filter_height, filter_width, in_channels - (i + 1), - channel_multiplier - ])) - - # Concat along in_channel. - results.append( - array_ops.concat(elements, axis=-2, name="in_channel_%d" % i)) - - # Concat along out_channel. - return array_ops.concat(results, axis=-1, name="out_channel") - - -def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin - """Converts a convolution filter for use with depthwise_conv2d. - - Transforms a filter for use with tf.nn.conv2d() to one that's - compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along - the diagonal. - - Args: - filter: Tensor of shape [height, width, in_channels, out_channels]. - name: None or str. Name of Op. - - Returns: - Tensor of shape, - [height, width, in_channels, channel_multiplier] - - Raises: - ValueError: if out_channels is not evenly divisible by in_channels. - """ - with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter", - [filter]): - filter = ops.convert_to_tensor(filter) - filter_height, filter_width, in_channels, out_channels = ( - filter.shape.as_list()) - - if out_channels % in_channels != 0: - raise ValueError("out_channels must be evenly divisible by in_channels.") - channel_multiplier = out_channels // in_channels - - results = [] - filter = array_ops.reshape(filter, [ - filter_height, filter_width, in_channels, in_channels, - channel_multiplier - ]) - for i in range(in_channels): - # Slice out output corresponding to the correct filter. - filter_slice = array_ops.reshape( - filter[:, :, i, i, :], - [filter_height, filter_width, 1, channel_multiplier]) - results.append(filter_slice) - - # Concat along out_channel. - return array_ops.concat(results, axis=-2, name="in_channels") - - -def maybe_tuple(obj): - if not isinstance(obj, list): - return obj - return tuple(obj) - - -def num_conv_locations(input_shape, strides): - """Returns the number of spatial locations a 2D Conv kernel is applied to. - - Args: - input_shape: List of ints representing shape of inputs to - tf.nn.convolution(). - strides: List of ints representing strides along spatial dimensions as - passed in to tf.nn.convolution(). - - Returns: - A scalar |T| denoting the number of spatial locations for the Conv layer. - """ - spatial_input_locations = np.prod(input_shape[1:-1]) - - if strides is None: - spatial_strides_divisor = 1 - else: - spatial_strides_divisor = np.prod(strides) - - return spatial_input_locations // spatial_strides_divisor - - -class InputOutputMultiTowerMultiUse(InputOutputMultiTower): - """Adds methods for multi-use/time-step case to InputOutputMultiTower.""" - - def __init__(self, num_uses=None, *args, **kwargs): - self._num_uses = num_uses - super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs) - - def _process_data(self, grads_list): - """Process temporal/multi-use data into the format used by the factors. 
- - This function takes inputs and grads_lists data and processes it into - one of the formats expected by the FisherFactor classes (depending on - the value of the global configuration variable TOWER_STRATEGY). - - It accepts the data in one of two initial formats. The first possible - format is where self._inputs is a list of list of Tensors. The first index - is tower, the second is use/time-step. grads_list, meanwhile, is a list - over sources of such lists of lists. - - The second possible data format is where self._inputs is a Tensor with - uses/times-steps folded into the batch dimension. i.e. it is a Tensor - of shape [num_uses * size_batch, ...] which represents a reshape of a - Tensor of shape [num_uses, size_batch, ...]. And similarly grads_list is - a list over sources of such Tensors. - - There are two possible formats which inputs and grads_list are transformed - into. - - If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing - a single tensor (represented as a PartitionedTensor object) with all of - the data from the towers, as well as the uses/time-steps, concatenated - together. In this tensor the leading dimension is the batch and - use/time-step dimensions folded together (with 'use' being the major of - these two, so that the tensors can be thought of as reshapes of ones of - shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a - tuple over sources of such tensors. - - If TOWER_STRATEGY is "separate" the inputs are formatted into lists of - tensors over towers. Each of these tensors has a similar format to - the tensor produced by the "concat" option, except that each contains - only the data from a single tower. grads_list is similarly formatted - into a tuple over sources of such tuples. - - Args: - grads_list: grads_list in its initial format (see above). - - Returns: - inputs: self._inputs transformed into the appropriate format (see - above). - grads_list: grads_list transformed into the appropriate format (see - above). - - Raises: - ValueError: If TOWER_STRATEGY is not one of "separate" or "concat". - ValueError: If the given/initial format of self._inputs and grads_list - isn't recognized, or doesn't agree with self._num_uses. - """ - - inputs = self._inputs - - if isinstance(inputs[0], (list, tuple)): - num_uses = len(inputs[0]) - if self._num_uses is not None and self._num_uses != num_uses: - raise ValueError("num_uses argument doesn't match length of inputs.") - else: - self._num_uses = num_uses - - # Check that all mini-batches/towers have the same number of uses - if not all(len(input_) == num_uses for input_ in inputs): - raise ValueError("Length of inputs argument is inconsistent across " - "towers.") - - if fisher_factors.TOWER_STRATEGY == "concat": - # Reverse the tower and use/time-step indices, so that use is now first, - # and towers is second - inputs = tuple(zip(*inputs)) - - # Flatten the two dimensions - inputs = nest.flatten(inputs) - - # Merge everything together into a PartitionedTensor. We package it in - # a singleton tuple since the factors will expect a list over towers - inputs = (utils.PartitionedTensor(inputs),) - - elif fisher_factors.TOWER_STRATEGY == "separate": - # Merge together the uses/time-step dimension into PartitionedTensors, - # but keep the leading dimension (towers) intact for the factors to - # process individually. 
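-        # (E.g., hypothetically with two towers and two uses each,
-        #  [[x00, x01], [x10, x11]] becomes (PartitionedTensor([x00, x01]),
-        #  PartitionedTensor([x10, x11])).)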
-        inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs)
-
-      else:
-        raise ValueError("Global config variable TOWER_STRATEGY must be one "
-                         "of 'concat' or 'separate'.")
-    else:
-      inputs = tuple(inputs)
-
-    # Now we perform the analogous processing for grads_list.
-    if isinstance(grads_list[0][0], (list, tuple)):
-      num_uses = len(grads_list[0][0])
-      if self._num_uses is not None and self._num_uses != num_uses:
-        raise ValueError("num_uses argument doesn't match length of outputs, "
-                         "or length of outputs is inconsistent with length of "
-                         "inputs.")
-      else:
-        self._num_uses = num_uses
-
-      if not all(len(grad) == num_uses for grads in grads_list
-                 for grad in grads):
-        raise ValueError("Length of outputs argument is inconsistent across "
-                         "towers.")
-
-      if fisher_factors.TOWER_STRATEGY == "concat":
-        # Reverse the tower and use/time-step indices, so that use is now
-        # first, and towers is second.
-        grads_list = tuple(tuple(zip(*grads)) for grads in grads_list)
-
-        # Flatten the two dimensions, leaving the leading dimension (source)
-        # intact.
-        grads_list = tuple(nest.flatten(grads) for grads in grads_list)
-
-        # Merge inner dimensions together into PartitionedTensors. We package
-        # them in a singleton tuple since the factors will expect a list over
-        # towers.
-        grads_list = tuple((utils.PartitionedTensor(grads),)
-                           for grads in grads_list)
-
-      elif fisher_factors.TOWER_STRATEGY == "separate":
-        # Merge together the uses/time-step dimension into PartitionedTensors,
-        # but keep the leading dimension (towers) intact for the factors to
-        # process individually.
-        grads_list = tuple(tuple(utils.PartitionedTensor(grad)
-                                 for grad in grads)
-                           for grads in grads_list)
-
-      else:
-        raise ValueError("Global config variable TOWER_STRATEGY must be one "
-                         "of 'concat' or 'separate'.")
-    else:
-      grads_list = tuple(tuple(grads) for grads in grads_list)
-
-    if self._num_uses is None:
-      raise ValueError("You must supply a value for the num_uses argument if "
-                       "the number of uses cannot be inferred from inputs or "
-                       "outputs arguments (e.g. if they are both given in the "
-                       "single Tensor format, instead of as lists of "
-                       "Tensors).")
-
-    return inputs, grads_list
-
-
-class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse,
-                                 KroneckerProductFB):
-  """FisherBlock for fully-connected layers that share parameters.
-
-  This class implements the "independence across time" approximation from the
-  following paper:
-    https://openreview.net/pdf?id=HyMTkQZAb
-  """
-
-  def __init__(self, layer_collection, has_bias=False, num_uses=None):
-    """Creates a FullyConnectedMultiIndepFB block.
-
-    Args:
-      layer_collection: LayerCollection instance.
-      has_bias: bool. If True, estimates Fisher with respect to a bias
-        parameter as well as the layer's parameters.
-      num_uses: int or None. Number of uses of the layer in the model's graph.
-        Only required if the data is formatted with uses/time folded into the
-        batch dimension (instead of uses/time being a list dimension).
- (Default: None) - """ - self._has_bias = has_bias - - super(FullyConnectedMultiIndepFB, self).__init__( - layer_collection=layer_collection, - num_uses=num_uses) - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, - ((inputs,), self._num_uses, self._has_bias)) - - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) - - self._setup_damping(damping, normalization=self._num_uses) - - @property - def _renorm_coeff(self): - return float(self._num_uses) - - -class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse, - KroneckerProductFB): - """FisherBlock for 2D convolutional layers using the basic KFC approx. - - Similar to ConvKFCBasicFB except that this version supports multiple - uses/time-steps via a standard independence approximation. Similar to the - "independence across time" used in FullyConnectedMultiIndepFB but generalized - in the obvious way to conv layers. - """ - - def __init__(self, - layer_collection, - params, - padding, - strides=None, - dilation_rate=None, - data_format=None, - extract_patches_fn=None, - num_uses=None): - """Creates a ConvKFCBasicMultiIndepFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: The parameters (Tensor or tuple of Tensors) of this layer. If - kernel alone, a Tensor of shape [..spatial_filter_shape.., - in_channels, out_channels]. If kernel and bias, a tuple of 2 elements - containing the previous and a Tensor of shape [out_channels]. - padding: str. Padding method. - strides: List of ints or None. Contains [..spatial_filter_strides..] if - 'extract_patches_fn' is compatible with tf.nn.convolution(), else - [1, ..spatial_filter_strides, 1]. - dilation_rate: List of ints or None. Rate for dilation along each spatial - dimension if 'extract_patches_fn' is compatible with - tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. - data_format: str or None. Format of input data. - extract_patches_fn: str or None. Name of function that extracts image - patches. One of "extract_convolution_patches", "extract_image_patches", - "extract_pointwise_conv2d_patches". - num_uses: int or None. Number of uses of the layer in the model's graph. - Only required if the data is formatted with uses/time folded into the - batch dimension (instead of uses/time being a list dimension). - (Default: None) - """ - self._padding = padding - self._strides = maybe_tuple(strides) - self._dilation_rate = maybe_tuple(dilation_rate) - self._data_format = data_format - self._extract_patches_fn = extract_patches_fn - self._has_bias = isinstance(params, (tuple, list)) - - fltr = params[0] if self._has_bias else params - self._filter_shape = tuple(fltr.shape.as_list()) - - super(ConvKFCBasicMultiIndepFB, self).__init__( - layer_collection=layer_collection, - num_uses=num_uses) - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - # Infer number of locations upon which convolution is applied. 
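-    # As a worked example of num_conv_locations (shapes assumed): for inputs
-    # of shape [batch, 32, 32, channels] and strides [1, 2, 2, 1], it returns
-    # (32 * 32) // (1 * 2 * 2 * 1) = 256 spatial locations.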
- self._num_locations = num_conv_locations(inputs[0].shape.as_list(), - self._strides) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.ConvInputKroneckerFactor, - (inputs, self._filter_shape, self._padding, self._strides, - self._dilation_rate, self._data_format, self._extract_patches_fn, - self._has_bias)) - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) - - self._setup_damping(damping, normalization= - (self._num_locations * self._num_uses)) - - @property - def _renorm_coeff(self): - return self._num_locations * self._num_uses - - -class EmbeddingKFACMultiIndepFB(InputOutputMultiTowerMultiUse, - KroneckerProductFB): - """K-FAC FisherBlock for embedding layers used multiple times in the graph. - - Similar to EmbeddingKFACFB except that this version supports multiple uses - of the parameter within a single model. These uses could correspond to time - steps in an RNN architecture, but they don't have to. - - Does not support bias parameters. - """ - - def __init__(self, layer_collection, vocab_size, num_uses=None): - """Creates a EmbeddingKFACMultiIndepFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - vocab_size: int. Size of vocabulary for this embedding layer. - num_uses: int or None. Number of uses of the layer in the model's graph. - Only required if the data is formatted with time folded into the batch - dimension (instead of time being a list dimension). (Default: None) - """ - self._vocab_size = vocab_size - - super(EmbeddingKFACMultiIndepFB, self).__init__( - layer_collection=layer_collection, - num_uses=num_uses) - - def instantiate_factors(self, grads_list, damping): - """Instantiate Kronecker Factors for this FisherBlock. - - Args: - grads_list: List of list of list of Tensors. grads_list[i][j][k] is the - gradient of the loss with respect to 'outputs' from source 'i', - tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape - [tower_minibatch_size, output_size]. - damping: 0-D Tensor or float. 'damping' * identity is approximately added - to this FisherBlock's Fisher approximation. - """ - inputs, grads_list = self._process_data(grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.EmbeddingInputKroneckerFactor, - (inputs, self._vocab_size)) - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) - self._setup_damping(damping, normalization=self._num_uses) - - @property - def _renorm_coeff(self): - return float(self._num_uses) - - -class SeriesFBApproximation(enum.IntEnum): - """See FullyConnectedSeriesFB.__init__ for description and usage.""" - option1 = 1 - option2 = 2 - - -class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse, - KroneckerProductFB): - """FisherBlock for fully-connected layers that share parameters across time. - - This class implements the "Option 1" and "Option 2" approximation from the - following paper: - https://openreview.net/pdf?id=HyMTkQZAb - - See the end of the appendix of the paper for a pseudo-code of the - algorithm being implemented by multiply_matpower here. Note that we are - using pre-computed versions of certain matrix-matrix products to speed - things up. This is explicitly explained wherever it is done. 
- """ - - def __init__(self, - layer_collection, - has_bias=False, - num_uses=None, - option=SeriesFBApproximation.option2): - """Constructs a new `FullyConnectedSeriesFB`. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - has_bias: Whether the layer includes a bias parameter. - num_uses: int or None. Number of time-steps over which the layer - is used. Only required if the data is formatted with time folded into - the batch dimension (instead of time being a list dimension). - (Default: None) - option: A `SeriesFBApproximation` specifying the simplifying assumption - to be used in this block. `option1` approximates the cross-covariance - over time as a symmetric matrix, while `option2` makes - the assumption that training sequences are infinitely long. See section - 3.5 of the paper for more details. - """ - - self._has_bias = has_bias - self._option = option - - super(FullyConnectedSeriesFB, self).__init__( - layer_collection=layer_collection, - num_uses=num_uses) - - @property - def _num_timesteps(self): - return self._num_uses - - @property - def _renorm_coeff(self): - # This should no longer be used since the multiply_X functions from the base - # class have been overridden - assert False - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, - ((inputs,), self._num_uses, self._has_bias)) - self._input_factor.register_cov_dt1() - - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) - self._output_factor.register_cov_dt1() - - self._setup_damping(damping, normalization=self._num_uses) - - def register_matpower(self, exp): - if exp != -1: - raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" - "multiplications.") - - if self._option == SeriesFBApproximation.option1: - self._input_factor.register_option1quants(self._input_damping_func) - self._output_factor.register_option1quants(self._output_damping_func) - elif self._option == SeriesFBApproximation.option2: - self._input_factor.register_option2quants(self._input_damping_func) - self._output_factor.register_option2quants(self._output_damping_func) - else: - raise ValueError( - "Unrecognized FullyConnectedSeriesFB approximation: {}".format( - self._option)) - - def multiply_matpower(self, vector, exp): - if exp != -1: - raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" - "multiplications.") - - # pylint: disable=invalid-name - - Z = utils.layer_params_to_mat2d(vector) - - # Derivations were done for "batch_dim==1" case so we need to convert to - # that orientation: - Z = array_ops.transpose(Z) - - if self._option == SeriesFBApproximation.option1: - - # Note that \\(L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.\\) - L_A, psi_A = self._input_factor.get_option1quants( - self._input_damping_func) - L_G, psi_G = self._output_factor.get_option1quants( - self._output_damping_func) - - def gamma(x): - # We are assuming that each case has the same number of time-steps. - # If this stops being the case one shouldn't simply replace this T - # with its average value. Instead, one needs to go back to the - # definition of the gamma function from the paper. 
-        T = self._num_timesteps
-        return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T))
-
-      # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise)
-      # Even though Y is Z-independent we are recomputing it from the psi's
-      # each time, since Y depends on both A and G quantities, and it is
-      # relatively cheap to compute.
-      Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A)
-
-      # \\(Z = L_G^T * Z * L_A\\)
-      # This is equivalent to the following computation from the original
-      # pseudo-code:
-      # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
-      # \\(Z = U_G^T * Z * U_A\\)
-      Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True)
-
-      # \\(Z = Z .* Y\\)
-      Z *= Y
-
-      # \\(Z = L_G * Z * L_A^T\\)
-      # This is equivalent to the following computation from the original
-      # pseudo-code:
-      # \\(Z = U_G * Z * U_A^T\\)
-      # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
-      Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True))
-
-    elif self._option == SeriesFBApproximation.option2:
-
-      # Note that \\(P_A = A_1^T * A_0^{-1}\\) and \\(P_G = G_1^T * G_0^{-1}\\),
-      # and \\(K_A = A_0^{-1/2} * E_A\\) and \\(K_G = G_0^{-1/2} * E_G\\).
-      P_A, K_A, mu_A = self._input_factor.get_option2quants(
-          self._input_damping_func)
-      P_G, K_G, mu_G = self._output_factor.get_option2quants(
-          self._output_damping_func)
-
-      # Our approach differs superficially from the pseudo-code in the paper
-      # in order to reduce the total number of matrix-matrix multiplies.
-      # In particular, the first three computations in the pseudo code are
-      # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
-      # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\)
-      # \\(Z = E_G^T * Z * E_A\\)
-      # Noting that \\(hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that
-      # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\),
-      # the entire computation can be written as
-      # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
-      # \\(    - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\)
-      # \\(  = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
-      # \\(    - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\)
-      # \\(  = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\)
-      # \\(    - E_G^T * G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\)
-      # \\(  = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A\\)
-      # This final expression is computed by the following two lines:
-      # \\(Z = Z - P_G * Z * P_A^T\\)
-      Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True))
-      # \\(Z = K_G^T * Z * K_A\\)
-      Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True)
-
-      # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\)
-      # Be careful with the outer product. We don't want to accidentally
-      # make it an inner-product instead.
-      tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A
-      # Prevent some numerical issues by setting any 0.0 eigs to 1.0.
-      tmp += 1.0 * math_ops.cast(math_ops.equal(tmp, 0.0), dtype=tmp.dtype)
-      Z /= tmp
-
-      # We now perform the transpose/reverse version of the operations
-      # derived above, whose derivation from the original pseudo-code is
-      # analogous.
-      # \\(Z = K_G * Z * K_A^T\\)
-      Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True))
-
-      # \\(Z = Z - P_G^T * Z * P_A\\)
-      Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True)
-
-    # Normalize: \\(Z = (1/T) * Z\\)
-    # Note that this normalization is done because we compute the statistics
-    # by averaging, not summing, over time. (And the gradient is presumably
-    # summed over time, not averaged, and thus their scales are different.)
- Z /= math_ops.cast(self._num_timesteps, Z.dtype) - - # Convert back to the "batch_dim==0" orientation. - Z = array_ops.transpose(Z) - - return utils.mat2d_to_layer_params(vector, Z) - - # pylint: enable=invalid-name - - def multiply_cholesky(self, vector): - raise NotImplementedError("FullyConnectedSeriesFB does not support " - "Cholesky computations.") - - def multiply_cholesky_inverse(self, vector): - raise NotImplementedError("FullyConnectedSeriesFB does not support " - "Cholesky computations.") - diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py deleted file mode 100644 index c04cf727fa..0000000000 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""FisherBlock definitions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.fisher_blocks import * -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import - -_allowed_symbols = [ - 'FisherBlock', - 'FullFB', - 'NaiveDiagonalFB', - 'FullyConnectedDiagonalFB', - 'KroneckerProductFB', - 'EmbeddingKFACFB', - 'FullyConnectedKFACBasicFB', - 'ConvKFCBasicFB', - 'ConvDiagonalFB', - 'set_global_constants', - 'compute_pi_tracenorm', - 'compute_pi_adjusted_damping', - 'num_conv_locations', - 'normalize_damping', - 'LEFT_MULTIPLY', - 'RIGHT_MULTIPLY', -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py deleted file mode 100644 index afa2fd1ca7..0000000000 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ /dev/null @@ -1,1830 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ==============================================================================
-"""FisherFactor definitions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-import contextlib
-
-import numpy as np
-import six
-
-from tensorflow.contrib.kfac.python.ops import linear_operator as lo
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import special_math_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.training import moving_averages
-from tensorflow.python.util import nest
-
-
-# Whether to initialize covariance estimators at a zero matrix (or the
-# identity matrix).
-INIT_COVARIANCES_AT_ZERO = True
-
-# Whether to zero-debias the moving averages.
-ZERO_DEBIAS = True
-
-# Whether to initialize inverses (and other such matrices computed from the
-# cov matrices) to the zero matrix (or the identity matrix).
-INIT_INVERSES_AT_ZERO = True
-
-# When the number of inverses requested from a FisherFactor exceeds this value,
-# the inverses are computed using an eigenvalue decomposition.
-EIGENVALUE_DECOMPOSITION_THRESHOLD = 2
-
-# Numerical eigenvalues computed from covariance matrix estimates are clipped
-# to be at least as large as this value before they are used to compute
-# inverses or matrix powers. Must be nonnegative.
-EIGENVALUE_CLIPPING_THRESHOLD = 0.0
-
-# Used to subsample the flattened extracted image patches. The number of
-# outer products per row of the covariance matrix should not exceed this
-# value. This parameter is used only if `_SUB_SAMPLE_OUTER_PRODUCTS` is True.
-_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1
-
-# Used to subsample the inputs passed to extract image patches. The batch size
-# (i.e. the number of inputs from which patches are extracted) is multiplied
-# by this factor. This parameter is used only if `_SUB_SAMPLE_INPUTS` is True.
-_INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.5
-
-# If True, then subsamples the tensor passed to compute the covariance matrix.
-_SUB_SAMPLE_OUTER_PRODUCTS = False
-
-# If True, then subsamples the inputs from which the image patches are
-# extracted.
-_SUB_SAMPLE_INPUTS = False
-
-# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data
-# passed to the factors from the blocks will be concatenated across towers
-# (lazily via PartitionedTensor objects). Otherwise a tuple of tensors over
-# towers will be passed in, and the factors will iterate over this and do the
-# cov computations separately for each one, averaging the results together.
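-# As a sketch of how this is meant to be configured (via the
-# set_global_constants() helper defined below):
-#
-#   from tensorflow.contrib.kfac.python.ops import fisher_factors
-#   fisher_factors.set_global_constants(tower_strategy="separate")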
-TOWER_STRATEGY = "concat" - - -def set_global_constants(init_covariances_at_zero=None, - zero_debias=None, - init_inverses_at_zero=None, - eigenvalue_decomposition_threshold=None, - eigenvalue_clipping_threshold=None, - max_num_outer_products_per_cov_row=None, - sub_sample_outer_products=None, - inputs_to_extract_patches_factor=None, - sub_sample_inputs=None, - tower_strategy=None): - """Sets various global constants used by the classes in this module.""" - global INIT_COVARIANCES_AT_ZERO - global ZERO_DEBIAS - global INIT_INVERSES_AT_ZERO - global EIGENVALUE_DECOMPOSITION_THRESHOLD - global EIGENVALUE_CLIPPING_THRESHOLD - global _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW - global _SUB_SAMPLE_OUTER_PRODUCTS - global _INPUTS_TO_EXTRACT_PATCHES_FACTOR - global _SUB_SAMPLE_INPUTS - global TOWER_STRATEGY - - if init_covariances_at_zero is not None: - INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero - if zero_debias is not None: - ZERO_DEBIAS = zero_debias - if init_inverses_at_zero is not None: - INIT_INVERSES_AT_ZERO = init_inverses_at_zero - if eigenvalue_decomposition_threshold is not None: - EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold - if eigenvalue_clipping_threshold is not None: - EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold - if max_num_outer_products_per_cov_row is not None: - _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = max_num_outer_products_per_cov_row - if sub_sample_outer_products is not None: - _SUB_SAMPLE_OUTER_PRODUCTS = sub_sample_outer_products - if inputs_to_extract_patches_factor is not None: - _INPUTS_TO_EXTRACT_PATCHES_FACTOR = inputs_to_extract_patches_factor - if sub_sample_inputs is not None: - _SUB_SAMPLE_INPUTS = sub_sample_inputs - if tower_strategy is not None: - TOWER_STRATEGY = tower_strategy - - -def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument - if INIT_INVERSES_AT_ZERO: - return array_ops.zeros(shape, dtype=dtype) - return linalg_ops.eye(num_rows=shape[0], dtype=dtype) - - -def covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument - if INIT_COVARIANCES_AT_ZERO: - return array_ops.zeros(shape, dtype=dtype) - return linalg_ops.eye(num_rows=shape[0], dtype=dtype) - - -def diagonal_covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument - if INIT_COVARIANCES_AT_ZERO: - return array_ops.zeros(shape, dtype=dtype) - return array_ops.ones(shape, dtype=dtype) - - -@contextlib.contextmanager -def place_on_device(device): - if device is not None and len(device): - with tf_ops.device(device): - yield - else: - yield - - -def compute_cov(tensor, tensor_right=None, normalizer=None): - """Compute the empirical second moment of the rows of a 2D Tensor. - - This function is meant to be applied to random matrices for which the true row - mean is zero, so that the true second moment equals the true covariance. - - Args: - tensor: A 2D Tensor. - tensor_right: An optional 2D Tensor. If provided, this function computes - the matrix product tensor^T * tensor_right instead of tensor^T * tensor. - normalizer: optional scalar for the estimator (by default, the normalizer is - the number of rows of tensor). - - Returns: - A square 2D Tensor with as many rows/cols as the number of input columns. 
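-
-  For example (a sketch with assumed values): for
-  tensor = [[1., 2.], [3., 4.]] and the default normalizer of 2, this
-  computes (tensor^T tensor) / 2, symmetrized, i.e.
-  [[5., 7.], [7., 10.]].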
- """ - if normalizer is None: - normalizer = array_ops.shape(tensor)[0] - if tensor_right is None: - cov = ( - math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast( - normalizer, tensor.dtype)) - return (cov + array_ops.transpose(cov)) / math_ops.cast(2.0, cov.dtype) - else: - return (math_ops.matmul(tensor, tensor_right, transpose_a=True) / - math_ops.cast(normalizer, tensor.dtype)) - - -def append_homog(tensor): - """Appends a homogeneous coordinate to the last dimension of a Tensor. - - Args: - tensor: A Tensor. - - Returns: - A Tensor identical to the input but one larger in the last dimension. The - new entries are filled with ones. - """ - rank = len(tensor.shape.as_list()) - shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0) - ones = array_ops.ones(shape, dtype=tensor.dtype) - return array_ops.concat([tensor, ones], axis=rank - 1) - - -def scope_string_from_params(params): - """Builds a variable scope string name from the given parameters. - - Supported parameters are: - * tensors - * booleans - * ints - * strings - * depth-1 tuples/lists of ints - * any depth tuples/lists of tensors - Other parameter types will throw an error. - - Args: - params: A parameter or list of parameters. - - Returns: - A string to use for the variable scope. - - Raises: - ValueError: if params includes an unsupported type. - """ - params = params if isinstance(params, (tuple, list)) else (params,) - - name_parts = [] - for param in params: - if param is None: - name_parts.append("None") - elif isinstance(param, (tuple, list)): - if all([isinstance(p, int) for p in param]): - name_parts.append("-".join([str(p) for p in param])) - else: - name_parts.append(scope_string_from_name(param)) - elif isinstance(param, (str, int, bool)): - name_parts.append(str(param)) - elif isinstance(param, (tf_ops.Tensor, variables.Variable)): - name_parts.append(scope_string_from_name(param)) - elif isinstance(param, utils.PartitionedTensor): - name_parts.append(scope_string_from_name(param.tensors)) - else: - raise ValueError("Encountered an unsupported param type {}".format( - type(param))) - return "_".join(name_parts) - - -def scope_string_from_name(tensor): - if isinstance(tensor, (tuple, list)): - return "__".join([scope_string_from_name(t) for t in tensor]) - # "gradients/add_4_grad/Reshape:0" -> "gradients_add_4_grad_Reshape" - return tensor.name.split(":")[0].replace("/", "_") - - -def scalar_or_tensor_to_string(val): - return repr(val) if np.isscalar(val) else scope_string_from_name(val) - - -def list_to_string(lst): - return "_".join(val if isinstance(val, six.string_types) - else scalar_or_tensor_to_string(val) for val in lst) - - -def graph_func_to_id(func): - """Returns a hashable object that represents func's computation.""" - # TODO(b/74201126): replace with Topohash of func's output - return func.func_id - - -def graph_func_to_string(func): - # TODO(b/74201126): replace with Topohash of func's output - return list_to_string(func.func_id) - - -def _subsample_for_cov_computation(array, name=None): - """Subsamples the first dimension of the array. - - `array`(A) is a tensor of shape `[batch_size, dim_2]`. Then the covariance - matrix(A^TA) is of shape `dim_2 ** 2`. Subsample only if the number of outer - products per row of the covariance matrix is greater than - `_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW`. - - Args: - array: Tensor, of shape `[batch_size, dim_2]`. - name: `string`, Default(None) - - Returns: - A tensor of shape `[max_samples, dim_2]`. 
-
-  Raises:
-    ValueError: If array is not matrix-shaped.
-    ValueError: If array's batch_size cannot be inferred.
-  """
-  with tf_ops.name_scope(name, "subsample", [array]):
-    array = tf_ops.convert_to_tensor(array)
-    if len(array.shape) != 2:
-      raise ValueError("Input param array must be a matrix.")
-
-    batch_size = array.shape.as_list()[0]
-    if batch_size is None:
-      raise ValueError("Unable to get batch_size from input param array.")
-
-    num_cov_rows = array.shape.as_list()[-1]
-    max_batch_size = int(_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW * num_cov_rows)
-    if batch_size <= max_batch_size:
-      return array
-
-    return _random_tensor_gather(array, max_batch_size)
-
-
-def _random_tensor_gather(array, max_size):
-  """Generates a random set of indices and gathers the values at the indices.
-
-  Args:
-    array: Tensor, of shape `[batch_size, dim_2]`.
-    max_size: int, Number of indices to sample.
-
-  Returns:
-    A tensor of shape `[max_size, ...]`.
-  """
-  batch_size = array.shape.as_list()[0]
-  indices = random_ops.random_shuffle(math_ops.range(0, batch_size))[:max_size]
-  return array_ops.gather(array, indices)
-
-
-@six.add_metaclass(abc.ABCMeta)
-class FisherFactor(object):
-  """Base class for objects modeling factors of approximate Fisher blocks.
-
-  A FisherFactor represents part of an approximate Fisher Information matrix.
-  For example, one approximation to the Fisher uses the Kronecker product of
-  two FisherFactors A and B, F = kron(A, B). FisherFactors are composed with
-  FisherBlocks to construct a block-diagonal approximation to the full Fisher.
-
-  FisherFactors are backed by a single, non-trainable variable that is updated
-  by running FisherFactor.make_covariance_update_op(). The shape and type of
-  this variable is implementation specific.
-
-  Note that for blocks that aren't based on approximations, a 'factor' can
-  be the entire block itself, as is the case for the diagonal and full
-  representations.
-  """
-
-  def __init__(self):
-    self._cov = None
-
-  @abc.abstractproperty
-  def _var_scope(self):
-    """Variable scope for this FisherFactor instance.
-
-    Returns:
-      string that uniquely identifies this FisherFactor instance.
-    """
-    pass
-
-  @property
-  def name(self):
-    return self._var_scope
-
-  @abc.abstractproperty
-  def _cov_shape(self):
-    """The shape of the variable backing this FisherFactor."""
-    pass
-
-  @abc.abstractproperty
-  def _num_sources(self):
-    """The number of things to sum over when updating covariance variable.
-
-    The default make_covariance_update_op function will call _compute_new_cov
-    with indices ranging from 0 to _num_sources-1. The typical situation is
-    where the factor wants to sum the statistics it computes over multiple
-    backpropped "gradients" (typically passed in via "tensors" or
-    "outputs_grads" arguments).
- """ - pass - - @abc.abstractproperty - def _num_towers(self): - pass - - @abc.abstractproperty - def _dtype(self): - """dtype for variable backing this factor.""" - pass - - @property - def _cov_initializer(self): - """Function for initializing covariance variable.""" - return covariance_initializer - - def instantiate_cov_variables(self): - """Makes the internal cov variable(s).""" - assert self._cov is None - with variable_scope.variable_scope(self._var_scope): - self._cov = variable_scope.get_variable( - "cov", - initializer=self._cov_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - - @abc.abstractmethod - def _compute_new_cov(self, source, tower): - """Computes minibatch-estimated covariance for a single source. - - Args: - source: int in [0, self._num_sources). Which source to use when computing - the cov update. - tower: int in [0, self._num_towers). Which tower to use when computing - the cov update. - - Returns: - Tensor of same shape as self.get_cov(). - """ - pass - - def make_covariance_update_op(self, ema_decay): - """Constructs and returns the covariance update Op. - - Args: - ema_decay: The exponential moving average decay (float or Tensor). - Returns: - An Op for updating the covariance Variable referenced by _cov. - """ - new_cov_contribs = [] - for source in range(self._num_sources): - for tower in range(self._num_towers): - device = (self._get_data_device(tower) - if TOWER_STRATEGY == "separate" else None) - with place_on_device(device): - new_cov_contribs.append(self._compute_new_cov(source, tower)) - - new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers) - - # Compute average of 'new_cov' across all TPU cores. On a TPU, each - # instance of 'new_cov' will be based on a different minibatch. This ensures - # that by the end of assign_moving_average(), all TPU cores see the same - # value for self._cov. - # - # Other implementations of make_covariance_update_op() that accumulate - # statistics in other variables should mimic this behavior. - if utils.on_tpu(): - new_cov = utils.cross_replica_mean(new_cov) - - return moving_averages.assign_moving_average( - self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) - - @abc.abstractmethod - def _get_data_device(self, tower): - pass - - @abc.abstractmethod - def instantiate_inv_variables(self): - """Makes the internal "inverse" variable(s).""" - pass - - @abc.abstractmethod - def make_inverse_update_ops(self): - """Create and return update ops corresponding to registered computations.""" - pass - - def get_cov(self): - return self._cov - - @abc.abstractmethod - def get_cov_as_linear_operator(self): - pass - - @abc.abstractmethod - def register_matpower(self, exp, damping_func): - pass - - @abc.abstractmethod - def register_cholesky(self, damping_func): - pass - - @abc.abstractmethod - def register_cholesky_inverse(self, damping_func): - pass - - @abc.abstractmethod - def get_matpower(self, exp, damping_func): - pass - - @abc.abstractmethod - def get_cholesky(self, damping_func): - pass - - @abc.abstractmethod - def get_cholesky_inverse(self, damping_func): - pass - - -class DenseSquareMatrixFactor(FisherFactor): - """Base class for FisherFactors that are stored as dense square matrices. - - This class explicitly calculates and stores inverses of their `cov` matrices, - which must be square dense matrices. - - Subclasses must implement the _compute_new_cov method, and the _var_scope and - _cov_shape properties. 
- """ - - # TODO(b/69108481): This class (and its subclasses) should be refactored to - # serve the matrix quantities it computes as both (potentially stale) - # variables, updated by the inverse update ops, and fresh values stored in - # tensors that recomputed once every session.run() call. Currently matpower - # and damp_inverse have the former behavior, while eigendecomposition has - # the latter. - - def __init__(self): - self._matpower_by_exp_and_damping = {} # { (float, hashable): variable } - self._matpower_registrations = set() # { (float, hashable) } - self._eigendecomp = None - self._damping_funcs_by_id = {} # {hashable: lambda} - - self._cholesky_registrations = set() # { hashable } - self._cholesky_inverse_registrations = set() # { hashable } - - self._cholesky_by_damping = {} # { hashable: variable } - self._cholesky_inverse_by_damping = {} # { hashable: variable } - - super(DenseSquareMatrixFactor, self).__init__() - - def get_cov_as_linear_operator(self): - assert self.get_cov().shape.ndims == 2 - return lo.LinearOperatorFullMatrix(self.get_cov(), - is_self_adjoint=True, - is_square=True) - - def _register_damping(self, damping_func): - damping_id = graph_func_to_id(damping_func) - if damping_id not in self._damping_funcs_by_id: - self._damping_funcs_by_id[damping_id] = damping_func - return damping_id - - def register_inverse(self, damping_func): - # Just for backwards compatibility of some old code and tests - self.register_matpower(-1, damping_func) - - def register_matpower(self, exp, damping_func): - """Registers a matrix power to be maintained and served on demand. - - This creates a variable and signals make_inverse_update_ops to make the - corresponding update op. The variable can be read via the method - get_matpower. - - Args: - exp: float. The exponent to use in the matrix power. - damping_func: A function that computes a 0-D Tensor or a float which will - be the damping value used. i.e. damping = damping_func(). - """ - if exp == 1.0: - return - - damping_id = self._register_damping(damping_func) - - if (exp, damping_id) not in self._matpower_registrations: - self._matpower_registrations.add((exp, damping_id)) - - def register_cholesky(self, damping_func): - """Registers a Cholesky factor to be maintained and served on demand. - - This creates a variable and signals make_inverse_update_ops to make the - corresponding update op. The variable can be read via the method - get_cholesky. - - Args: - damping_func: A function that computes a 0-D Tensor or a float which will - be the damping value used. i.e. damping = damping_func(). - """ - damping_id = self._register_damping(damping_func) - - if damping_id not in self._cholesky_registrations: - self._cholesky_registrations.add(damping_id) - - def register_cholesky_inverse(self, damping_func): - """Registers an inverse Cholesky factor to be maintained/served on demand. - - This creates a variable and signals make_inverse_update_ops to make the - corresponding update op. The variable can be read via the method - get_cholesky_inverse. - - Args: - damping_func: A function that computes a 0-D Tensor or a float which will - be the damping value used. i.e. damping = damping_func(). 
- """ - damping_id = self._register_damping(damping_func) - - if damping_id not in self._cholesky_inverse_registrations: - self._cholesky_inverse_registrations.add(damping_id) - - def instantiate_inv_variables(self): - """Makes the internal "inverse" variable(s).""" - - for (exp, damping_id) in self._matpower_registrations: - exp_string = scalar_or_tensor_to_string(exp) - damping_func = self._damping_funcs_by_id[damping_id] - damping_string = graph_func_to_string(damping_func) - with variable_scope.variable_scope(self._var_scope): - matpower = variable_scope.get_variable( - "matpower_exp{}_damp{}".format(exp_string, damping_string), - initializer=inverse_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - assert (exp, damping_id) not in self._matpower_by_exp_and_damping - self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower - - for damping_id in self._cholesky_registrations: - damping_func = self._damping_funcs_by_id[damping_id] - damping_string = graph_func_to_string(damping_func) - with variable_scope.variable_scope(self._var_scope): - chol = variable_scope.get_variable( - "cholesky_damp{}".format(damping_string), - initializer=inverse_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - assert damping_id not in self._cholesky_by_damping - self._cholesky_by_damping[damping_id] = chol - - for damping_id in self._cholesky_inverse_registrations: - damping_func = self._damping_funcs_by_id[damping_id] - damping_string = graph_func_to_string(damping_func) - with variable_scope.variable_scope(self._var_scope): - cholinv = variable_scope.get_variable( - "cholesky_inverse_damp{}".format(damping_string), - initializer=inverse_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - assert damping_id not in self._cholesky_inverse_by_damping - self._cholesky_inverse_by_damping[damping_id] = cholinv - - def make_inverse_update_ops(self): - """Create and return update ops corresponding to registered computations.""" - ops = [] - - num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping - if exp == -1) - - num_other_matpower = len(self._matpower_by_exp_and_damping) - num_inverses - - other_matrix_power_registered = num_other_matpower >= 1 - - use_eig = ( - self._eigendecomp or other_matrix_power_registered or - num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) - - # We precompute these so we don't need to evaluate them multiple times (for - # each matrix power that uses them) - damping_value_by_id = {damping_id: math_ops.cast( - self._damping_funcs_by_id[damping_id](), self._dtype) - for damping_id in self._damping_funcs_by_id} - - if use_eig: - eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence - - for (exp, damping_id), matpower in ( - self._matpower_by_exp_and_damping.items()): - damping = damping_value_by_id[damping_id] - ops.append( - matpower.assign( - math_ops.matmul(eigenvectors * - (eigenvalues + damping)**exp, - array_ops.transpose(eigenvectors)))) - # These ops share computation and should be run on a single device. - ops = [control_flow_ops.group(*ops)] - else: - for (exp, damping_id), matpower in ( - self._matpower_by_exp_and_damping.items()): - assert exp == -1 - damping = damping_value_by_id[damping_id] - ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping))) - - # TODO(b/77902055): If inverses are being computed with Cholesky's - # we can share the work. Instead this code currently just computes the - # Cholesky a second time. 
It does at least share work between requests for - # Cholesky's and Cholesky inverses with the same damping id. - for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items(): - cholesky_ops = [] - - damping = damping_value_by_id[damping_id] - cholesky_value = utils.cholesky(self.get_cov(), damping) - - if damping_id in self._cholesky_by_damping: - cholesky = self._cholesky_by_damping[damping_id] - cholesky_ops.append(cholesky.assign(cholesky_value)) - - identity = linalg_ops.eye(cholesky_value.shape.as_list()[0], - dtype=cholesky_value.dtype) - cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value, - identity) - cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value)) - - ops.append(control_flow_ops.group(*cholesky_ops)) - - for damping_id, cholesky in self._cholesky_by_damping.items(): - if damping_id not in self._cholesky_inverse_by_damping: - damping = damping_value_by_id[damping_id] - cholesky_value = utils.cholesky(self.get_cov(), damping) - ops.append(cholesky.assign(cholesky_value)) - - self._eigendecomp = False - return ops - - def get_inverse(self, damping_func): - # Just for backwards compatibility of some old code and tests - return self.get_matpower(-1, damping_func) - - def get_matpower(self, exp, damping_func): - # Note that this function returns a variable which gets updated by the - # inverse ops. It may be stale / inconsistent with the latest value of - # get_cov(). - if exp != 1: - damping_id = graph_func_to_id(damping_func) - matpower = self._matpower_by_exp_and_damping[(exp, damping_id)] - else: - matpower = self.get_cov() - identity = linalg_ops.eye(matpower.shape.as_list()[0], - dtype=matpower.dtype) - matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity - - assert matpower.shape.ndims == 2 - return lo.LinearOperatorFullMatrix(matpower, - is_non_singular=True, - is_self_adjoint=True, - is_positive_definite=True, - is_square=True) - - def get_cholesky(self, damping_func): - # Note that this function returns a variable which gets updated by the - # inverse ops. It may be stale / inconsistent with the latest value of - # get_cov(). - damping_id = graph_func_to_id(damping_func) - cholesky = self._cholesky_by_damping[damping_id] - assert cholesky.shape.ndims == 2 - return lo.LinearOperatorFullMatrix(cholesky, - is_non_singular=True, - is_square=True) - - def get_cholesky_inverse(self, damping_func): - # Note that this function returns a variable which gets updated by the - # inverse ops. It may be stale / inconsistent with the latest value of - # get_cov(). - damping_id = graph_func_to_id(damping_func) - cholesky_inv = self._cholesky_inverse_by_damping[damping_id] - assert cholesky_inv.shape.ndims == 2 - return lo.LinearOperatorFullMatrix(cholesky_inv, - is_non_singular=True, - is_square=True) - - def get_eigendecomp(self): - """Creates or retrieves eigendecomposition of self._cov.""" - # Unlike get_matpower this doesn't retrieve a stored variable, but instead - # always computes a fresh version from the current value of get_cov(). 
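-    # The pair (eigenvalues, eigenvectors) computed here is what
-    # make_inverse_update_ops uses to form damped matrix powers:
-    # \\((C + damping*I)^{exp} = V (Lambda + damping)^{exp} V^T\\),
-    # where \\(C = V diag(Lambda) V^T\\) is the decomposition of get_cov().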
-    if not self._eigendecomp:
-      eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov())
-
-      # The matrix self._cov is positive semidefinite by construction, but the
-      # numerical eigenvalues could be negative due to numerical errors, so
-      # here we clip them to be at least EIGENVALUE_CLIPPING_THRESHOLD.
-      clipped_eigenvalues = math_ops.maximum(eigenvalues,
-                                             EIGENVALUE_CLIPPING_THRESHOLD)
-      self._eigendecomp = (clipped_eigenvalues, eigenvectors)
-
-    return self._eigendecomp
-
-
-class FullFactor(DenseSquareMatrixFactor):
-  """FisherFactor for a full matrix representation of the Fisher of a parameter.
-
-  Note that this uses the naive "square the sum estimator", and so is
-  applicable to any type of parameter in principle, but has very high
-  variance.
-  """
-
-  def __init__(self,
-               params_grads,
-               batch_size):
-    self._batch_size = batch_size
-    self._params_grads = tuple(utils.ensure_sequence(params_grad)
-                               for params_grad in params_grads)
-    super(FullFactor, self).__init__()
-
-  @property
-  def _var_scope(self):
-    return "ff_full_" + scope_string_from_params(
-        [self._params_grads, self._batch_size])
-
-  @property
-  def _cov_shape(self):
-    size = sum(param_grad.shape.num_elements()
-               for param_grad in self._params_grads[0])
-    return (size, size)
-
-  @property
-  def _num_sources(self):
-    return len(self._params_grads)
-
-  @property
-  def _num_towers(self):
-    return 1
-
-  @property
-  def _dtype(self):
-    return self._params_grads[0][0].dtype
-
-  def _compute_new_cov(self, source, tower):
-    assert tower == 0
-
-    # This will be a very basic rank-1 estimate.
-    params_grads_flat = utils.tensors_to_column(self._params_grads[source])
-    return ((params_grads_flat * array_ops.transpose(
-        params_grads_flat)) / math_ops.cast(self._batch_size,
-                                            params_grads_flat.dtype))
-
-  def _get_data_device(self, tower):
-    return None
-
-
-class DiagonalFactor(FisherFactor):
-  """A base class for FisherFactors that use diagonal approximations.
-
-  A DiagonalFactor's covariance variable can be of any shape, but must contain
-  exactly one entry per parameter.
-  """
-
-  def __init__(self):
-    super(DiagonalFactor, self).__init__()
-
-  def get_cov_as_linear_operator(self):
-    assert self._matrix_diagonal.shape.ndims == 1
-    return lo.LinearOperatorDiag(self._matrix_diagonal,
-                                 is_self_adjoint=True,
-                                 is_square=True)
-
-  @property
-  def _cov_initializer(self):
-    return diagonal_covariance_initializer
-
-  @property
-  def _matrix_diagonal(self):
-    return array_ops.reshape(self.get_cov(), [-1])
-
-  def make_inverse_update_ops(self):
-    return []
-
-  def instantiate_inv_variables(self):
-    pass
-
-  def register_matpower(self, exp, damping_func):
-    pass
-
-  def register_cholesky(self, damping_func):
-    pass
-
-  def register_cholesky_inverse(self, damping_func):
-    pass
-
-  def get_matpower(self, exp, damping_func):
-    matpower_diagonal = (self._matrix_diagonal
-                         + math_ops.cast(damping_func(), self._dtype))**exp
-    return lo.LinearOperatorDiag(matpower_diagonal,
-                                 is_non_singular=True,
-                                 is_self_adjoint=True,
-                                 is_positive_definite=True,
-                                 is_square=True)
-
-  def get_cholesky(self, damping_func):
-    return self.get_matpower(0.5, damping_func)
-
-  def get_cholesky_inverse(self, damping_func):
-    return self.get_matpower(-0.5, damping_func)
-
-
-class NaiveDiagonalFactor(DiagonalFactor):
-  """FisherFactor for a diagonal approximation of any type of param's Fisher.
-
-  Note that this uses the naive "square the sum estimator", and so is
-  applicable to any type of parameter in principle, but has very high
-  variance.
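-
-  Concretely (see _compute_new_cov below), each diagonal entry is estimated
-  as the squared gradient for the corresponding parameter, divided by the
-  batch size.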
- """ - - def __init__(self, - params_grads, - batch_size): - """Initializes NaiveDiagonalFactor instance. - - Args: - params_grads: Sequence of Tensors, each with same shape as parameters this - FisherFactor corresponds to. For example, the gradient of the loss with - respect to parameters. - batch_size: int or 0-D Tensor. Size - """ - self._params_grads = tuple(utils.ensure_sequence(params_grad) - for params_grad in params_grads) - self._batch_size = batch_size - super(NaiveDiagonalFactor, self).__init__() - - @property - def _var_scope(self): - return "ff_naivediag_" + scope_string_from_params( - [self._params_grads, self._batch_size]) - - @property - def _cov_shape(self): - size = sum(param_grad.shape.num_elements() - for param_grad in self._params_grads[0]) - return [size, 1] - - @property - def _num_sources(self): - return len(self._params_grads) - - @property - def _num_towers(self): - return 1 - - @property - def _dtype(self): - return self._params_grads[0][0].dtype - - def _compute_new_cov(self, source, tower): - assert tower == 0 - - params_grads_flat = utils.tensors_to_column(self._params_grads[source]) - return (math_ops.square(params_grads_flat) / math_ops.cast( - self._batch_size, params_grads_flat.dtype)) - - def _get_data_device(self, tower): - return None - - -class EmbeddingInputKroneckerFactor(DiagonalFactor): - r"""FisherFactor for input to an embedding layer. - - Given input_ids = [batch_size, input_size] representing indices into an - [vocab_size, embedding_size] embedding matrix, approximate input covariance by - a diagonal matrix, - - Cov(input_ids, input_ids) = - (1/batch_size) sum_{i} diag(n_hot(input[i]) ** 2). - - where n_hot() constructs an n-hot binary vector and diag() constructs a - diagonal matrix of size [vocab_size, vocab_size]. - """ - - def __init__(self, input_ids, vocab_size, dtype=None): - """Instantiate EmbeddingInputKroneckerFactor. - - Args: - input_ids: List of Tensors of shape [batch_size, input_size] and dtype - int32. Indices into embedding matrix. List index is tower. - vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'. - dtype: dtype for covariance statistics. Must be a floating point type. - Defaults to float32. - """ - self._input_ids = input_ids - self._vocab_size = vocab_size - self._cov_dtype = dtype or dtypes.float32 - - super(EmbeddingInputKroneckerFactor, self).__init__() - - @property - def _var_scope(self): - return "ff_diag_embedding_" + scope_string_from_params(self._input_ids) - - @property - def _cov_shape(self): - return [self._vocab_size] - - @property - def _num_sources(self): - return 1 - - @property - def _num_towers(self): - return len(self._input_ids) - - @property - def _dtype(self): - return self._cov_dtype - - def _compute_new_cov(self, source, tower): - assert source == 0 - - input_ids = self._input_ids[tower] - - if len(input_ids.shape) > 2: - raise ValueError( - "Input to embeddings must have rank <= 2. Found rank %d." % len( - input_ids.shape)) - - batch_size = array_ops.shape(input_ids)[0] - - # Transform indices into one-hot vectors. - # - # TODO(b/72714822): There must be a faster way to construct the diagonal - # covariance matrix! This operation is O(batch_size * vocab_size), where - # it should be O(batch_size * input_size). - flat_input_ids = array_ops.reshape(input_ids, [-1]) - one_hots = array_ops.one_hot(flat_input_ids, - self._vocab_size) # [?, vocab_size] - - # Take average across examples. 
Note that, because all entries have - # magnitude zero or one, there's no need to square the entries. - # - # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation - # within an example such as average. - # - # TODO(b/72714822): Support for partitioned embeddings. - new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size] - new_cov /= math_ops.cast(batch_size, new_cov.dtype) - - return new_cov - - def _get_data_device(self, tower): - return self._input_ids[tower].device - - -class FullyConnectedDiagonalFactor(DiagonalFactor): - r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher. - - Given in = [batch_size, input_size] and out_grad = [batch_size, output_size], - approximates the covariance as, - - Cov(in, out) = (1/batch_size) sum_{i} outer(in[i], out_grad[i]) ** 2.0 - - where the square is taken element-wise. - """ - - def __init__(self, - inputs, - outputs_grads, - has_bias=False): - """Instantiate FullyConnectedDiagonalFactor. - - Args: - inputs: List of Tensors of shape [batch_size, input_size]. Inputs to this - layer. List index is towers. - outputs_grads: List of Tensors, each of shape [batch_size, output_size], - which are the gradients of the loss with respect to the layer's - outputs. First index is source, second is tower. - - has_bias: bool. If True, append '1' to each input. - """ - self._inputs = inputs - self._has_bias = has_bias - self._outputs_grads = outputs_grads - self._squared_inputs = None - - super(FullyConnectedDiagonalFactor, self).__init__() - - @property - def _var_scope(self): - return "ff_diagfc_" + scope_string_from_params( - tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) - - @property - def _cov_shape(self): - input_size = self._inputs[0].shape[1] + self._has_bias - output_size = self._outputs_grads[0][0].shape[1] - return [input_size, output_size] - - @property - def _num_sources(self): - return len(self._outputs_grads) - - @property - def _num_towers(self): - return len(self._inputs) - - @property - def _dtype(self): - return self._outputs_grads[0][0].dtype - - def make_covariance_update_op(self, ema_decay): - - self._squared_inputs = [] - for tower in range(self._num_towers): - inputs = self._inputs[tower] - - with place_on_device(self._get_data_device(tower)): - if self._has_bias: - inputs = append_homog(inputs) - self._squared_inputs.append(math_ops.square(inputs)) - - return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op( - ema_decay) - - def _compute_new_cov(self, source, tower): - batch_size = array_ops.shape(self._squared_inputs[tower])[0] - outputs_grad = self._outputs_grads[source][tower] - - # The well-known special formula that uses the fact that the entry-wise - # square of an outer product is the outer-product of the entry-wise squares. - # The gradient is the outer product of the input and the output gradients, - # so we just square both and then take their outer-product. - new_cov = math_ops.matmul( - self._squared_inputs[tower], - math_ops.square(outputs_grad), - transpose_a=True) - new_cov /= math_ops.cast(batch_size, new_cov.dtype) - return new_cov - - def _get_data_device(self, tower): - return self._inputs[tower].device - - -class ConvDiagonalFactor(DiagonalFactor): - """FisherFactor for a diagonal approx of a convolutional layer's Fisher.""" - - def __init__(self, - inputs, - outputs_grads, - filter_shape, - strides, - padding, - data_format=None, - dilations=None, - has_bias=False): - """Creates a ConvDiagonalFactor object. 
-
-    Args:
-      inputs: List of Tensors of shape [batch_size, height, width,
-        in_channels]. Input activations to this layer. List index is towers.
-      outputs_grads: List of Tensors, each of shape [batch_size,
-        height, width, out_channels], which are the gradients of the loss
-        with respect to the layer's outputs. First index is source, second
-        index is tower.
-      filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels,
-        out_channels). Represents shape of kernel used in this layer.
-      strides: The stride size in this layer (1-D Tensor of length 4).
-      padding: str. The padding method used in this layer ("SAME" or "VALID").
-      data_format: None or str. Format of conv2d inputs.
-      dilations: None or tuple of 4 ints.
-      has_bias: Python bool. If True, the layer is assumed to have a bias
-        parameter in addition to its filter parameter.
-
-    Raises:
-      ValueError: If inputs, output_grads, and filter_shape do not agree on
-        in_channels or out_channels.
-      ValueError: If strides or dilations are not length-4 lists of ints.
-      ValueError: If data_format does not put channel last.
-    """
-    if not utils.is_data_format_channel_last(data_format):
-      raise ValueError("Channel must be last.")
-    if any(input_.shape.ndims != 4 for input_ in inputs):
-      raise ValueError("inputs must be a list of 4-D Tensors.")
-    if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs):
-      raise ValueError("inputs and filter_shape must agree on in_channels.")
-    for i, outputs_grad in enumerate(outputs_grads):
-      if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad):
-        raise ValueError("outputs[%d] must be a 4-D Tensor." % i)
-      if any(output_grad.shape.as_list()[-1] != filter_shape[-1]
-             for output_grad in outputs_grad):
-        raise ValueError(
-            "outputs[%d] and filter_shape must agree on out_channels." % i)
-    if len(strides) != 4:
-      raise ValueError("strides must be a length-4 list of ints.")
-    if dilations is not None and len(dilations) != 4:
-      raise ValueError("dilations must be a length-4 list of ints.")
-
-    self._inputs = inputs
-    self._outputs_grads = outputs_grads
-    self._filter_shape = filter_shape
-    self._strides = strides
-    self._padding = padding
-    self._data_format = data_format
-    self._dilations = dilations
-    self._has_bias = has_bias
-    self._patches = None
-
-    super(ConvDiagonalFactor, self).__init__()
-
-  @property
-  def _var_scope(self):
-    return "ff_convdiag_" + scope_string_from_params(
-        tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads)))
-
-  @property
-  def _cov_shape(self):
-    filter_height, filter_width, in_channels, out_channels = self._filter_shape
-    return [
-        filter_height * filter_width * in_channels + self._has_bias,
-        out_channels
-    ]
-
-  @property
-  def _num_sources(self):
-    return len(self._outputs_grads)
-
-  @property
-  def _num_towers(self):
-    return len(self._inputs)
-
-  @property
-  def _dtype(self):
-    return self._inputs[0].dtype
-
-  def make_covariance_update_op(self, ema_decay):
-    filter_height, filter_width, _, _ = self._filter_shape
-
-    # TODO(b/64144716): there is potential here for a big savings in terms
-    # of memory use.
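-    # Shape sketch (assuming NHWC inputs): extract_image_patches below maps
-    # each [batch_size, height, width, in_channels] input to patches of shape
-    # [batch_size, out_height, out_width,
-    #  filter_height * filter_width * in_channels], which
-    # _convdiag_sum_of_squares later contracts against the output gradients.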
- if self._dilations is None:
- rates = (1, 1, 1, 1)
- else:
- rates = tuple(self._dilations)
-
- self._patches = []
- for tower in range(self._num_towers):
- with place_on_device(self._get_data_device(tower)):
- patches = array_ops.extract_image_patches(
- self._inputs[tower],
- ksizes=[1, filter_height, filter_width, 1],
- strides=self._strides,
- rates=rates,
- padding=self._padding)
-
- if self._has_bias:
- patches = append_homog(patches)
-
- self._patches.append(patches)
-
- return super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay)
-
- def _compute_new_cov(self, source, tower):
- patches = self._patches[tower]
- batch_size = array_ops.shape(patches)[0]
- outputs_grad = self._outputs_grads[source][tower]
-
- new_cov = self._convdiag_sum_of_squares(patches, outputs_grad)
- new_cov /= math_ops.cast(batch_size, new_cov.dtype)
-
- return new_cov
-
- def _convdiag_sum_of_squares(self, patches, outputs_grad):
- # This computes the sum of the squares of the per-training-case "gradients".
- # It does this simply by computing a giant tensor containing all of these,
- # doing an entry-wise square, and then summing along the batch dimension.
- case_wise_gradients = special_math_ops.einsum("bijk,bijl->bkl", patches,
- outputs_grad)
- return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0)
-
- def _get_data_device(self, tower):
- return self._inputs[tower].device
-
-
-class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor):
- """Kronecker factor for the input or output side of a fully-connected layer.
- """
-
- def __init__(self,
- tensors,
- has_bias=False):
- """Instantiate FullyConnectedKroneckerFactor.
-
- Args:
- tensors: List of list of Tensors, each of shape [batch_size, n]. The
- Tensors are typically either a layer's inputs or its output's gradients.
- The first list index is source, the second is tower.
- has_bias: bool. If True, append '1' to each row.
- """
- # The tensor argument is either a tensor of input activations or a tensor of
- # output pre-activation gradients.
- self._has_bias = has_bias
- self._tensors = tensors
- super(FullyConnectedKroneckerFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_fckron_" + scope_string_from_params(
- tuple(nest.flatten(self._tensors)) + (self._has_bias,))
-
- @property
- def _cov_shape(self):
- size = self._tensors[0][0].shape[1] + self._has_bias
- return [size, size]
-
- @property
- def _num_sources(self):
- return len(self._tensors)
-
- @property
- def _num_towers(self):
- return len(self._tensors[0])
-
- @property
- def _dtype(self):
- return self._tensors[0][0].dtype
-
- def _compute_new_cov(self, source, tower):
- tensor = self._tensors[source][tower]
- if self._has_bias:
- tensor = append_homog(tensor)
- return compute_cov(tensor)
-
- def _get_data_device(self, tower):
- return self._tensors[0][tower].device
-
-
-class ConvInputKroneckerFactor(DenseSquareMatrixFactor):
- r"""Kronecker factor for the input side of a convolutional layer.
-
- Estimates E[ a a^T ] where a are the inputs to a convolutional layer given
- example x. Expectation is taken over all examples and locations.
-
- Equivalent to Omega in https://arxiv.org/abs/1602.01407. See Section 3.1,
- "Estimating the factors", for details.
- """
-
- def __init__(self,
- inputs,
- filter_shape,
- padding,
- strides=None,
- dilation_rate=None,
- data_format=None,
- extract_patches_fn=None,
- has_bias=False,
- sub_sample_inputs=None,
- sub_sample_patches=None):
- """Initializes ConvInputKroneckerFactor.
-
- Args:
- inputs: List of Tensors of shape [batch_size, ..spatial_input_size..,
- in_channels]. Inputs to layer. List index is tower.
- filter_shape: List of ints. Contains [..spatial_filter_size..,
- in_channels, out_channels]. Shape of convolution kernel.
- padding: str. Padding method for layer. "SAME" or "VALID".
- strides: List of ints or None. Contains [..spatial_filter_strides..] if
- 'extract_patches_fn' is compatible with tf.nn.convolution(), else
- [1, ..spatial_filter_strides.., 1].
- dilation_rate: List of ints or None. Rate for dilation along each spatial
- dimension if 'extract_patches_fn' is compatible with
- tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
- data_format: str or None. Format of input data.
- extract_patches_fn: str or None. Name of function that extracts image
- patches. One of "extract_convolution_patches", "extract_image_patches",
- "extract_pointwise_conv2d_patches".
- has_bias: bool. If True, append '1' to each patch.
- sub_sample_inputs: `bool`. If True, subsample the inputs from which
- the image patches are extracted. (Default: None)
- sub_sample_patches: `bool`. If `True`, subsample the extracted
- patches. (Default: None)
- """
- self._inputs = inputs
- self._filter_shape = filter_shape
- self._strides = strides
- self._padding = padding
- self._dilation_rate = dilation_rate
- self._data_format = data_format
- self._extract_patches_fn = extract_patches_fn
- self._has_bias = has_bias
- if sub_sample_inputs is None:
- self._sub_sample_inputs = _SUB_SAMPLE_INPUTS
- else:
- self._sub_sample_inputs = sub_sample_inputs
-
- if sub_sample_patches is None:
- self._sub_sample_patches = _SUB_SAMPLE_OUTER_PRODUCTS
- else:
- self._sub_sample_patches = sub_sample_patches
- super(ConvInputKroneckerFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_convinkron_" + scope_string_from_params(
- tuple(self._inputs) +
- tuple((self._filter_shape, self._strides, self._padding,
- self._dilation_rate, self._data_format, self._has_bias)))
-
- @property
- def _cov_shape(self):
- spatial_filter_shape = self._filter_shape[0:-2]
- in_channels = self._filter_shape[-2]
- size = np.prod(spatial_filter_shape) * in_channels + self._has_bias
- return [size, size]
-
- @property
- def _num_sources(self):
- return 1
-
- @property
- def _num_towers(self):
- return len(self._inputs)
-
- @property
- def _dtype(self):
- return self._inputs[0].dtype
-
- def _compute_new_cov(self, source, tower):
- assert source == 0
-
- inputs = self._inputs[tower]
- if self._sub_sample_inputs:
- batch_size = inputs.shape.as_list()[0]
- max_size = int(batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR)
- inputs = _random_tensor_gather(inputs, max_size)
-
- # TODO(b/64144716): there is potential here for a big savings in terms of
- # memory use.
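`_random_tensor_gather`, `_SUB_SAMPLE_INPUTS`, and `_INPUTS_TO_EXTRACT_PATCHES_FACTOR` are defined earlier in this file, outside this excerpt. A hedged sketch of what the gather plausibly does (the actual helper may differ): keep a random subset of batch rows to bound the cost of the patch extraction that follows.

```python
import tensorflow as tf

def random_tensor_gather_sketch(tensor, max_size):
  # Illustrative stand-in for _random_tensor_gather: select at most
  # `max_size` rows of `tensor` uniformly at random (TF1-era ops).
  indices = tf.random_shuffle(tf.range(tf.shape(tensor)[0]))[:max_size]
  return tf.gather(tensor, indices)
```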
- if self._extract_patches_fn in [None, "extract_convolution_patches"]:
- patches = utils.extract_convolution_patches(
- inputs,
- self._filter_shape,
- padding=self._padding,
- strides=self._strides,
- dilation_rate=self._dilation_rate,
- data_format=self._data_format)
-
- elif self._extract_patches_fn == "extract_image_patches":
- assert inputs.shape.ndims == 4
- assert len(self._filter_shape) == 4
- assert len(self._strides) == 4, self._strides
- if self._dilation_rate is None:
- rates = [1, 1, 1, 1]
- else:
- rates = self._dilation_rate
- assert len(rates) == 4
- assert rates[0] == rates[-1] == 1
- patches = array_ops.extract_image_patches(
- inputs,
- ksizes=[1] + list(self._filter_shape[0:-2]) + [1],
- strides=self._strides,
- rates=rates,
- padding=self._padding)
-
- elif self._extract_patches_fn == "extract_pointwise_conv2d_patches":
- assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)]
- assert self._filter_shape[0] == self._filter_shape[1] == 1
- patches = utils.extract_pointwise_conv2d_patches(
- inputs, self._filter_shape, data_format=None)
-
- else:
- raise NotImplementedError(self._extract_patches_fn)
-
- flatten_size = np.prod(self._filter_shape[0:-1])
- # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde
- # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14),
- # where M = minibatch size, |T| = number of spatial locations,
- # |Delta| = number of spatial offsets, and J = number of input maps
- # for convolutional layer l.
- patches_flat = array_ops.reshape(patches, [-1, flatten_size])
-
- if self._sub_sample_patches:
- patches_flat = _subsample_for_cov_computation(patches_flat)
-
- # We append a homogeneous coordinate to patches_flat if the layer has
- # bias parameters. This gives us [[A_l]]_H from the paper.
- if self._has_bias:
- patches_flat = append_homog(patches_flat)
- # We call compute_cov without passing in a normalizer. compute_cov uses
- # the first dimension of patches_flat i.e. M|T| as the normalizer by
- # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with
- # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from
- # the paper but has a different scale here for consistency with
- # ConvOutputKroneckerFactor.
- # (Tilde omitted over A for clarity.)
- return compute_cov(patches_flat)
-
- def _get_data_device(self, tower):
- return self._inputs[tower].device
-
-
-class ConvOutputKroneckerFactor(DenseSquareMatrixFactor):
- r"""Kronecker factor for the output side of a convolutional layer.
-
- Estimates E[ ds ds^T ] where s are the preactivations of a convolutional
- layer given example x and ds = (d / d s) log(p(y|x, w)). Expectation is
- taken over all examples and locations.
-
- Equivalent to Gamma in https://arxiv.org/abs/1602.01407. See Section 3.1,
- "Estimating the factors", for details.
- """
-
- def __init__(self, outputs_grads, data_format=None):
- """Initializes ConvOutputKroneckerFactor.
-
- Args:
- outputs_grads: List of list of Tensors. Each Tensor is of shape
- [batch_size, ..spatial_input_size.., out_channels]. First list index
- is source, the second is tower.
- data_format: None or str. Format of outputs_grads.
-
- Raises:
- ValueError: If channels are not the final dimension.
- """
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("Channel must be last.")
- self._out_channels = outputs_grads[0][0].shape.as_list()[-1]
- self._outputs_grads = outputs_grads
- super(ConvOutputKroneckerFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_convoutkron_" + scope_string_from_params(
- nest.flatten(self._outputs_grads))
-
- @property
- def _cov_shape(self):
- size = self._out_channels
- return [size, size]
-
- @property
- def _num_sources(self):
- return len(self._outputs_grads)
-
- @property
- def _num_towers(self):
- return len(self._outputs_grads[0])
-
- @property
- def _dtype(self):
- return self._outputs_grads[0][0].dtype
-
- def _compute_new_cov(self, source, tower):
- outputs_grad = self._outputs_grads[source][tower]
-
- # reshaped_tensor below is the matrix DS_l defined in the KFC paper
- # (tilde omitted over S for clarity). It has shape M|T| x I, where
- # M = minibatch size, |T| = number of spatial locations, and
- # I = number of output maps for convolutional layer l.
- reshaped_tensor = array_ops.reshape(outputs_grad, [-1, self._out_channels])
- # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov,
- # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
- # as defined in the paper, with shape I x I.
- # (Tilde omitted over S for clarity.)
- return compute_cov(reshaped_tensor)
-
- def _get_data_device(self, tower):
- return self._outputs_grads[0][tower].device
-
-
-class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
- """Kronecker factor for a fully connected layer used multiple times."""
-
- def __init__(self,
- tensors,
- num_uses=None,
- has_bias=False):
- """Constructs a new `FullyConnectedMultiKF`.
-
- Args:
- tensors: List of lists of Tensors, each of shape
- [num_uses * batch_size, n], i.e. a reshaped version of a Tensor of
- shape [num_uses, batch_size, n]. Each of these tensors is usually a
- layer's inputs or its output's gradients. The first list index is
- sources, the second is towers.
- num_uses: int. The number of time-steps / uses.
- has_bias: bool. If True, '1' is appended to each row.
- """
-
- self._num_uses = num_uses
-
- self._cov_dt1 = None
- self._make_cov_dt1 = False
- self._option1quants_by_damping = {}
- self._option2quants_by_damping = {}
- self._option1quants_registrations = set()
- self._option2quants_registrations = set()
-
- super(FullyConnectedMultiKF, self).__init__(tensors=tensors,
- has_bias=has_bias)
-
- @property
- def _num_timesteps(self):
- return self._num_uses
-
- @property
- def _var_scope(self):
- return "ff_fc_multi_" + scope_string_from_params(
- tuple(nest.flatten(self._tensors))
- + (self._num_timesteps, self._has_bias,))
-
- def make_covariance_update_op(self, ema_decay):
-
- op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay)
-
- if self._cov_dt1 is not None:
- new_cov_dt1_contribs = []
- for source in range(self._num_sources):
- for tower in range(self._num_towers):
- with place_on_device(self._get_data_device(tower)):
- new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source,
- tower))
-
- new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs)
- / float(self._num_towers))
-
- # See comments in FisherFactor.make_covariance_update_op() for details.
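To make the reshaping in ConvOutputKroneckerFactor above concrete, here is a small NumPy sketch (illustrative only) of how hat{Gamma}_l is formed, under the assumption stated in the comments that compute_cov normalizes by the first dimension:

```python
import numpy as np

batch, height, width, out_channels = 4, 7, 7, 16
ds = np.random.randn(batch, height, width, out_channels)  # output gradients

ds_flat = ds.reshape(-1, out_channels)               # shape [M|T|, I]
gamma_hat = ds_flat.T @ ds_flat / ds_flat.shape[0]   # 1/M|T| * DS^T DS
assert gamma_hat.shape == (out_channels, out_channels)
```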
- if utils.on_tpu(): - new_cov_dt1 = utils.cross_replica_mean(new_cov_dt1) - - op2 = moving_averages.assign_moving_average( - self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) - - # TODO(b/69112164): - # It's important that _cov and _cov_dt1 remain consistent with each - # other while the inverse ops are happening. How can we ensure this? - # We will need to add explicit synchronization for this to - # work with asynchronous training. - op = control_flow_ops.group(op, op2) - - return op - - def _compute_new_cov_dt1(self, source, tower): # pylint: disable=missing-docstring - tensor = self._tensors[source][tower] - if self._has_bias: - # This appending is technically done twice (the other time is for - # _compute_new_cov()) - tensor = append_homog(tensor) - - total_len = array_ops.shape(tensor)[0] - batch_size = total_len // self._num_timesteps - - tensor_present = tensor[:-batch_size, :] - tensor_future = tensor[batch_size:, :] - - # We specify a normalizer for this computation to ensure a PSD Fisher - # block estimate. This is equivalent to padding with zeros, as was done - # in Section B.2 of the appendix. - return compute_cov( - tensor_future, tensor_right=tensor_present, normalizer=total_len) - - def _get_data_device(self, tower): - return self._tensors[0][tower].device - - @property - def _vec_shape(self): - size = self._tensors[0][0].shape[1] + self._has_bias - return [size] - - def get_option1quants(self, damping_func): - damping_id = graph_func_to_id(damping_func) - return self._option1quants_by_damping[damping_id] - - def get_option2quants(self, damping_func): - damping_id = graph_func_to_id(damping_func) - return self._option2quants_by_damping[damping_id] - - def get_cov_dt1(self): - assert self._cov_dt1 is not None - return self._cov_dt1 - - def register_cov_dt1(self): - self._make_cov_dt1 = True - - def instantiate_cov_variables(self): - super(FullyConnectedMultiKF, self).instantiate_cov_variables() - assert self._cov_dt1 is None - if self._make_cov_dt1: - with variable_scope.variable_scope(self._var_scope): - self._cov_dt1 = variable_scope.get_variable( - "cov_dt1", - initializer=init_ops.zeros_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - - def register_option1quants(self, damping_func): - damping_id = self._register_damping(damping_func) - if damping_id not in self._option1quants_registrations: - self._option1quants_registrations.add(damping_id) - - def register_option2quants(self, damping_func): - damping_id = self._register_damping(damping_func) - if damping_id not in self._option2quants_registrations: - self._option2quants_registrations.add(damping_id) - - def instantiate_inv_variables(self): - super(FullyConnectedMultiKF, self).instantiate_inv_variables() - - for damping_id in self._option1quants_registrations: - damping_func = self._damping_funcs_by_id[damping_id] - damping_string = graph_func_to_string(damping_func) - # It's questionable as to whether we should initialize with stuff like - # this at all. Ideally these values should never be used until they are - # updated at least once. 
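The present/future slicing in `_compute_new_cov_dt1` above can be checked with a small NumPy sketch (illustrative only): for a tensor reshaped from [num_uses, batch_size, n] to [num_uses * batch_size, n], dropping the last/first batch_size rows pairs each time-step with its successor.

```python
import numpy as np

num_uses, batch_size, n = 5, 3, 4
x = np.random.randn(num_uses * batch_size, n)

present = x[:-batch_size, :]  # time-steps 0 .. T-2
future = x[batch_size:, :]    # time-steps 1 .. T-1

# Normalizing by the full length (not len(present)) keeps the block PSD,
# matching the zero-padding interpretation mentioned in the comments.
cov_dt1 = future.T @ present / x.shape[0]
assert cov_dt1.shape == (n, n)
```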
- with variable_scope.variable_scope(self._var_scope):
- Lmat = variable_scope.get_variable( # pylint: disable=invalid-name
- "Lmat_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- psi = variable_scope.get_variable(
- "psi_damp{}".format(damping_string),
- initializer=init_ops.ones_initializer,
- shape=self._vec_shape,
- trainable=False,
- dtype=self._dtype)
-
- assert damping_id not in self._option1quants_by_damping
- self._option1quants_by_damping[damping_id] = (Lmat, psi)
-
- for damping_id in self._option2quants_registrations:
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
- # It's questionable as to whether we should initialize with stuff like
- # this at all. Ideally these values should never be used until they are
- # updated at least once.
- with variable_scope.variable_scope(self._var_scope):
- Pmat = variable_scope.get_variable( # pylint: disable=invalid-name
- "Pmat_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- Kmat = variable_scope.get_variable( # pylint: disable=invalid-name
- "Kmat_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- mu = variable_scope.get_variable(
- "mu_damp{}".format(damping_string),
- initializer=init_ops.ones_initializer,
- shape=self._vec_shape,
- trainable=False,
- dtype=self._dtype)
-
- assert damping_id not in self._option2quants_by_damping
- self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu)
-
- def make_inverse_update_ops(self):
- """Create and return update ops corresponding to registered computations."""
- # TODO(b/69918258): Add correctness tests for this method.
- # pylint: disable=invalid-name
-
- ops = []
-
- if (len(self._option1quants_by_damping) +
- len(self._option2quants_by_damping)):
-
- # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from
- # the pseudo-code in the original paper. Because the computations for
- # the A and G case are essentially the same they can both be performed by
- # the same class (this one).
-
- C1 = self.get_cov_dt1()
-
- # Get the eigendecomposition of C0 (= self.get_cov())
- eigen_e, eigen_V = self.get_eigendecomp()
-
- # TODO(b/69678661): Note, there is an implicit assumption here that C1
- # and C0 (as represented here by its eigen-decomp) are consistent. This
- # could fail to be the case if self._cov and self._cov_dt1 are not updated
- # consistently, or are somehow read between or during the cov updates.
- # Can this possibly happen? Is there a way to prevent it?
-
- for damping_id, (Lmat_var,
- psi_var) in self._option1quants_by_damping.items():
-
- damping = self._damping_funcs_by_id[damping_id]()
- damping = math_ops.cast(damping, self._dtype)
-
- invsqrtC0 = math_ops.matmul(
- eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
-
- # Might need to enforce symmetry lost due to numerical issues.
- invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0
-
- # The following line imposes the symmetry assumed by "Option 1" on C1.
- # Strangely the code can work okay with this line commented out,
- # depending on how psd_eig is defined. I'm not sure why.
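For reference, the damped "Option 1" and "Option 2" quantities assembled in the loops here can be summarized as follows, with C_0 = V diag(e) V^T the eigendecomposition of the covariance, C_1 its time-shifted counterpart, and lambda the value returned by the registered damping function (a LaTeX paraphrase of the code comments):

```latex
\begin{align*}
C_0^{-1/2} &= V \operatorname{diag}\!\big((e+\lambda)^{-1/2}\big)\, V^\top,
  \qquad \hat{\Psi} = C_0^{-1/2} C_1 C_0^{-1/2} \\
\text{Option 1:}\quad & \hat{\Psi} = U \operatorname{diag}(\psi)\, U^\top,
  \qquad L = C_0^{-1/2} U \\
\text{Option 2:}\quad & \hat{\Psi}^\top \hat{\Psi}
  = E \operatorname{diag}(\mu)\, E^\top,
  \qquad P = C_1^\top C_0^{-1},
  \qquad K = C_0^{-1/2} E
\end{align*}
```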
- C1 = (C1 + array_ops.transpose(C1)) / 2.0
-
- # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi})
- hPsi = math_ops.matmul(math_ops.matmul(invsqrtC0, C1), invsqrtC0)
-
- # Compute the decomposition U*diag(psi)*U^T = hPsi
- psi, U = utils.posdef_eig(hPsi)
-
- # L = C0^(-1/2) * U
- Lmat = math_ops.matmul(invsqrtC0, U)
-
- ops.append(Lmat_var.assign(Lmat))
- ops.append(psi_var.assign(psi))
-
- for damping_id, (Pmat_var, Kmat_var,
- mu_var) in self._option2quants_by_damping.items():
-
- damping = self._damping_funcs_by_id[damping_id]()
- damping = math_ops.cast(damping, self._dtype)
-
- # compute C0^(-1/2)
- invsqrtC0 = math_ops.matmul(
- eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
-
- # Might need to enforce symmetry lost due to numerical issues.
- invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0
-
- # Compute the product C0^(-1/2) * C1
- invsqrtC0C1 = math_ops.matmul(invsqrtC0, C1)
-
- # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi})
- hPsi = math_ops.matmul(invsqrtC0C1, invsqrtC0)
-
- # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi
- # Note that we use the notation mu instead of "m" for the eigenvalues.
- # Instead of computing the product hPsi^T * hPsi and then doing an
- # eigen-decomposition of this we just compute the SVD of hPsi and then
- # square the singular values to get the eigenvalues. For a justification
- # of this approach, see:
- # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition
- sqrtmu, _, E = linalg_ops.svd(hPsi)
- mu = math_ops.square(sqrtmu)
-
- # Mathematically, the eigenvalues should not exceed 1.0, but due to
- # numerical issues, or possible issues with inconsistent values of C1
- # and (the eigen-decomposition of) C0, they might. So we enforce this
- # condition.
- mu = math_ops.minimum(mu, 1.0)
-
- # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1)
- Pmat = math_ops.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True)
-
- # K = C_0^(-1/2) * E
- Kmat = math_ops.matmul(invsqrtC0, E)
-
- ops.append(Pmat_var.assign(Pmat))
- ops.append(Kmat_var.assign(Kmat))
- ops.append(mu_var.assign(mu))
-
- ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops()
- return [control_flow_ops.group(*ops)]
-
- # pylint: enable=invalid-name
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
deleted file mode 100644
index 2d8e378a93..0000000000
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================== -"""FisherFactor definitions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.fisher_factors import * -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import - -_allowed_symbols = [ - "inverse_initializer", "covariance_initializer", - "diagonal_covariance_initializer", "scope_string_from_params", - "scope_string_from_name", "scalar_or_tensor_to_string", "FisherFactor", - "InverseProvidingFactor", "FullFactor", "DiagonalFactor", - "NaiveDiagonalFactor", "EmbeddingInputKroneckerFactor", - "FullyConnectedDiagonalFactor", "FullyConnectedKroneckerFactor", - "ConvInputKroneckerFactor", "ConvOutputKroneckerFactor", - "ConvDiagonalFactor", "set_global_constants", "maybe_colocate_with", - "compute_cov", "append_homog" -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py deleted file mode 100644 index 43aa713edc..0000000000 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ /dev/null @@ -1,1269 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Registry for layers and their parameters/variables. - -This represents the collection of all layers in the approximate Fisher -information matrix to which a particular FisherBlock may belong. That is, we -might have several layer collections for one TF graph (if we have multiple K-FAC -optimizers being used, for example.) -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from collections import defaultdict -from collections import OrderedDict -from contextlib import contextmanager -from functools import partial -import warnings - -import math -import six - -from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb -from tensorflow.contrib.kfac.python.ops import loss_functions as lf -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.util import nest - -# Names for various approximations that can be requested for Fisher blocks. 
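Before the approximation-name constants that follow, a brief orientation: a LayerCollection is populated by registering each layer's parameters and a predictive distribution, then handed to the optimizer. The sketch below is illustrative only; `KfacOptimizer` and `register_categorical_predictive_distribution` live in other files of this package, and their exact signatures are assumed from the era's contrib API rather than shown in this diff.

```python
import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.contrib.kfac.python.ops import optimizer as opt

# A one-layer softmax classifier, TF1-style.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.int32, [None])
w = tf.get_variable("w", [784, 10])
b = tf.get_variable("b", [10])
logits = tf.matmul(x, w) + b
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits))

# Register the layer and the model's predictive distribution.
layer_collection = lc.LayerCollection()
layer_collection.register_fully_connected((w, b), x, logits)
layer_collection.register_categorical_predictive_distribution(logits)

# Hand the collection to the K-FAC optimizer (assumed contrib signature).
optimizer = opt.KfacOptimizer(
    learning_rate=0.01,
    cov_ema_decay=0.95,
    damping=0.001,
    layer_collection=layer_collection)
train_op = optimizer.minimize(loss)
```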
-APPROX_KRONECKER_NAME = "kron" -APPROX_DIAGONAL_NAME = "diagonal" -APPROX_FULL_NAME = "full" - -_GENERIC_APPROX_TO_BLOCK_TYPES = { - APPROX_FULL_NAME: fb.FullFB, - APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB, -} - -_FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES = { - APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, - APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, -} - -_CONV2D_APPROX_TO_BLOCK_TYPES = { - APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, - APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, -} - -_EMBEDDING_APPROX_TO_BLOCK_TYPES = { - APPROX_KRONECKER_NAME: fb.EmbeddingKFACFB -} - -APPROX_KRONECKER_INDEP_NAME = "kron_indep" -APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1" -APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2" - -_FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = { - APPROX_KRONECKER_INDEP_NAME: fb.FullyConnectedMultiIndepFB, - APPROX_KRONECKER_SERIES_1_NAME: partial(fb.FullyConnectedSeriesFB, - option=1), - APPROX_KRONECKER_SERIES_2_NAME: partial(fb.FullyConnectedSeriesFB, - option=2) -} - -_CONV2D_MULTI_APPROX_TO_BLOCK_TYPES = { - APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB -} - -_EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = { - APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB -} - -# Possible value for `reuse` keyword argument. Sets `reuse` to -# tf.get_variable_scope().reuse. -VARIABLE_SCOPE = "VARIABLE_SCOPE" - -_DEFAULT_LAYER_COLLECTION = None - - -def get_default_layer_collection(): - """Get default LayerCollection.""" - if _DEFAULT_LAYER_COLLECTION is None: - raise ValueError( - "Attempted to retrieve default LayerCollection when none is set. Use " - "LayerCollection.as_default().") - - return _DEFAULT_LAYER_COLLECTION - - -def set_default_layer_collection(layer_collection): - global _DEFAULT_LAYER_COLLECTION - - if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None: - raise ValueError("Default LayerCollection is already set.") - - _DEFAULT_LAYER_COLLECTION = layer_collection - - -class LayerParametersDict(OrderedDict): - """An OrderedDict where keys are Tensors or tuples of Tensors. - - Ensures that no Tensor is associated with two different keys. - """ - - def __init__(self, *args, **kwargs): - self._tensors = set() - super(LayerParametersDict, self).__init__(*args, **kwargs) - - def __setitem__(self, key, value): - key = self._canonicalize_key(key) - tensors = key if isinstance(key, (tuple, list)) else (key,) - key_collisions = self._tensors.intersection(tensors) - if key_collisions: - raise ValueError("Key(s) already present: {}".format(key_collisions)) - self._tensors.update(tensors) - super(LayerParametersDict, self).__setitem__(key, value) - - def __delitem__(self, key): - key = self._canonicalize_key(key) - self._tensors.remove(key) - super(LayerParametersDict, self).__delitem__(key) - - def __getitem__(self, key): - key = self._canonicalize_key(key) - return super(LayerParametersDict, self).__getitem__(key) - - def __contains__(self, key): - key = self._canonicalize_key(key) - return super(LayerParametersDict, self).__contains__(key) - - def _canonicalize_key(self, key): - if isinstance(key, (list, tuple)): - return tuple(key) - return key - - -# TODO(b/68034464): add capability for LayerCollection to be "finalized" -# and do this when it gets used by FisherEstimator / KfacOptimizer. - - -class LayerCollection(object): - """Registry of information about layers and losses. - - Note that you need to create a new one of these for each MatrixEstimator or - KfacOptimizer. 
- - Attributes: - fisher_blocks: a LayersParamsDict (subclass of OrderedDict) mapping layer - parameters (Tensors or tuples of Tensors) to FisherBlock instances. - fisher_factors: an OrderedDict mapping tuples to FisherFactor instances. - losses: a list of LossFunction objects. The loss to be optimized is their - sum. - loss_colocation_ops: ops to colocate loss function evaluations with. These - will typically be the inputs to the losses. - """ - - def __init__(self, - graph=None, - name="LayerCollection"): - warnings.warn( - "tf.contrib.kfac is deprecated and will be removed by 2018-11-01. " - "Use https://pypi.python.org/pypi/kfac instead.") - self.fisher_blocks = LayerParametersDict() - self.fisher_factors = OrderedDict() - self._linked_parameters = dict( - ) # dict mapping sets of variables to optionally specified approximations. - self._graph = graph or ops.get_default_graph() - self._loss_dict = {} # {str: LossFunction} - self._subgraph = None - self._default_generic_approximation = APPROX_DIAGONAL_NAME - self._default_embedding_approximation = APPROX_KRONECKER_NAME - self._default_fully_connected_approximation = APPROX_KRONECKER_NAME - self._default_conv2d_approximation = APPROX_KRONECKER_NAME - self._default_fully_connected_multi_approximation = ( - APPROX_KRONECKER_INDEP_NAME) - self._default_conv2d_multi_approximation = ( - APPROX_KRONECKER_INDEP_NAME) - self._default_embedding_multi_approximation = APPROX_KRONECKER_INDEP_NAME - self.loss_colocation_ops = {} - self._vars_to_uses = defaultdict(lambda: 0) - - with variable_scope.variable_scope(None, default_name=name) as scope: - self._var_scope = scope.name - - @property - def losses(self): - """Tuple of LossFunction objects registered with this LayerCollection.""" - return nest.flatten(self.towers_by_loss) - - @property - def towers_by_loss(self): - """Tuple across losses of LossFunction objects registered to each tower.""" - return tuple(tuple(lst) for lst in self._loss_dict.values()) - - @property - def registered_variables(self): - """A tuple of all of the variables currently registered.""" - tuple_of_tuples = (utils.ensure_sequence(key) for key, block - in six.iteritems(self.fisher_blocks)) - flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_) - return flat_tuple - - @property - def linked_parameters(self): - """Groups of parameters with an optionally specified approximation. - - Linked parameters can be added using `define_linked_parameters`. - If an approximation is specified, then this approximation will be used - when registering a layer with exactly these parameters, unless an - approximation is specified when calling the registration function. - - Returns: - A `dict` mapping tuples of parameters to an optional string. 
- """ - return self._linked_parameters - - @property - def default_embedding_approximation(self): - return self._default_embedding_approximation - - def set_default_embedding_approximation(self, value): - if value != APPROX_KRONECKER_NAME: - raise ValueError( - "{} is not a valid approximation for embedding variables.".format( - value)) - self._default_embedding_approximation = value - - @property - def default_generic_approximation(self): - return self._default_generic_approximation - - def set_default_generic_approximation(self, value): - if value not in _GENERIC_APPROX_TO_BLOCK_TYPES: - raise ValueError( - "{} is not a valid approximation for generic variables.".format( - value)) - self._default_generic_approximation = value - - @property - def default_fully_connected_approximation(self): - return self._default_fully_connected_approximation - - def set_default_fully_connected_approximation(self, value): - if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: - raise ValueError( - "{} is not a valid approximation for fully connected layers.".format( - value)) - self._default_fully_connected_approximation = value - - @property - def default_conv2d_approximation(self): - return self._default_conv2d_approximation - - def set_default_conv2d_approximation(self, value): - if value not in _CONV2D_APPROX_TO_BLOCK_TYPES: - raise ValueError( - "{} is not a valid approximation for 2d convolutional layers.".format( - value)) - self._default_conv2d_approximation = value - - @property - def default_fully_connected_multi_approximation(self): - return self._default_fully_connected_multi_approximation - - def set_default_fully_connected_multi_approximation(self, value): - if value not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: - raise ValueError("{} is not a valid approximation for a fully-connected " - "multi layer.".format(value)) - self._default_fully_connected_multi_approximation = value - - @property - def default_conv2d_multi_approximation(self): - return self._default_conv2d_multi_approximation - - @property - def default_embedding_multi_approximation(self): - return self._default_embedding_multi_approximation - - def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): - """Validates and registers the layer_key associated with the fisher_block. - - Args: - layer_key: A variable or tuple of variables. The key to check for in - existing registrations and to register if valid. - fisher_block: The associated `FisherBlock`. - reuse: Method to use for inserting new `FisherBlock's. One of True, False, - or `VARIABLE_SCOPE`. - - Raises: - ValueError: If `layer_key` was already registered and reuse is `False`, - if `layer_key` was registered with a different block type, or if - `layer_key` shares any variables with but is not equal to a previously - registered key. - KeyError: If `reuse` is `True` but `layer_key` was not previously - registered. - - Returns: - The `FisherBlock` registered under `layer_key`. If `layer_key` was already - registered, this will be the previously registered `FisherBlock`. - """ - if reuse is VARIABLE_SCOPE: - reuse = variable_scope.get_variable_scope().reuse - - if reuse is True or (reuse is variable_scope.AUTO_REUSE and - layer_key in self.fisher_blocks): - result = self.fisher_blocks[layer_key] - if type(result) != type(fisher_block): # pylint: disable=unidiomatic-typecheck - raise ValueError( - "Attempted to register FisherBlock of type %s when existing " - "FisherBlock has type %s." 
% (type(fisher_block), type(result)))
- return result
- if reuse is False and layer_key in self.fisher_blocks:
- raise ValueError("FisherBlock for %s is already in LayerCollection." %
- (layer_key,))
-
- # Insert fisher_block into self.fisher_blocks.
- if layer_key in self.fisher_blocks:
- raise ValueError("Duplicate registration: {}".format(layer_key))
- # Raise an error if any variable in layer_key has been registered in any
- # other blocks.
- variable_to_block = {
- var: (params, block)
- for (params, block) in self.fisher_blocks.items()
- for var in utils.ensure_sequence(params)
- }
- for variable in utils.ensure_sequence(layer_key):
- if variable in variable_to_block:
- prev_key, prev_block = variable_to_block[variable]
- raise ValueError(
- "Attempted to register layer_key {} with block {}, but variable {}"
- " was already registered in key {} with block {}.".format(
- layer_key, fisher_block, variable, prev_key, prev_block))
- self.fisher_blocks[layer_key] = fisher_block
- return fisher_block
-
- def register_loss_function(self,
- loss,
- colocation_op,
- base_name,
- name=None,
- reuse=VARIABLE_SCOPE):
- """Registers a LossFunction object.
-
- Args:
- loss: The LossFunction object.
- colocation_op: The op to colocate the loss function's computations with.
- base_name: The name to derive a new unique name from if the name argument
- is None.
- name: (OPTIONAL) str or None. Unique name for this loss function. If None,
- a new name is generated. (Default: None)
- reuse: (OPTIONAL) bool or str. If True, adds `loss` as an additional
- tower for the existing loss function.
-
- Raises:
- ValueError: If reuse == True and name == None.
- KeyError: If reuse == True and no existing LossFunction with `name` found.
- KeyError: If reuse == False and existing LossFunction with `name` found.
- """
-
- name = name or self._graph.unique_name(base_name)
-
- if reuse == VARIABLE_SCOPE:
- reuse = variable_scope.get_variable_scope().reuse
-
- if reuse:
- if name is None:
- raise ValueError(
- "If reuse is enabled, loss function's name must be set.")
-
- loss_list = self._loss_dict.get(name, None)
-
- if loss_list is None:
- raise KeyError(
- "Unable to find loss function named {}. Register a new loss "
- "function with reuse=False.".format(name))
- else:
- if name in self._loss_dict:
- raise KeyError(
- "Loss function named {} already exists. Set reuse=True to append "
- "another tower.".format(name))
-
- loss_list = []
- self._loss_dict[name] = loss_list
-
- loss_list.append(loss)
- self.loss_colocation_ops[loss] = colocation_op
-
- def _get_use_count_map(self):
- """Returns a dict mapping variables to their number of registrations."""
- return self._vars_to_uses
-
- def _add_uses(self, params, uses):
- """Register additional uses by params in the graph.
-
- Args:
- params: Variable or tuple of Variables. Parameters for a layer.
- uses: int or float. Number of additional uses for these parameters.
- """
- params = params if isinstance(params, (tuple, list)) else (params,)
- for var in params:
- self._vars_to_uses[var] += uses
-
- def check_registration(self, variables):
- """Checks that all variable uses have been registered properly.
-
- Args:
- variables: List of variables.
-
- Raises:
- ValueError: If any registered variables are not included in the list.
- ValueError: If any variable in the list is not registered.
- ValueError: If any variable in the list is registered with the wrong - number of "uses" in the subgraph recorded (vs the number of times that - variable is actually used in the subgraph). - """ - # Note that overlapping parameters (i.e. those that share variables) will - # be caught by layer_collection.LayerParametersDict during registration. - - reg_use_map = self._get_use_count_map() - - error_messages = [] - - for var in variables: - total_uses = self.subgraph.variable_uses(var) - reg_uses = reg_use_map[var] - - if reg_uses == 0: - error_messages.append("Variable {} not registered.".format(var)) - elif (not math.isinf(reg_uses)) and reg_uses != total_uses: - error_messages.append( - "Variable {} registered with wrong number of uses ({} " - "registrations vs {} uses).".format(var, reg_uses, total_uses)) - - num_get_vars = len(reg_use_map) - - if num_get_vars > len(variables): - error_messages.append("{} registered variables were not included in list." - .format(num_get_vars - len(variables))) - - if error_messages: - error_messages = [ - "Found the following errors with variable registration:" - ] + error_messages - raise ValueError("\n\t".join(error_messages)) - - def get_blocks(self): - return self.fisher_blocks.values() - - def get_factors(self): - return self.fisher_factors.values() - - @property - def graph(self): - return self._graph - - @property - def subgraph(self): - return self._subgraph - - def define_linked_parameters(self, params, approximation=None): - """Identify a set of parameters that should be grouped together. - - During automatic graph scanning, any matches containing variables that have - been identified as part of a linked group will be filtered out unless - the match parameters are exactly equal to the ones specified in the linked - group. - - Args: - params: A variable, or a tuple or list of variables. The variables - to be linked. - approximation: Optional string specifying the type of approximation to use - for these variables. If unspecified, this layer collection's default - approximation for the layer type will be used. - - Raises: - ValueError: If the parameters were already registered in a layer or - identified as part of an incompatible group. - """ - params = frozenset(utils.ensure_sequence(params)) - - # Check if any of the variables in `params` is already in - # 'self.fisher_blocks.keys()`. - for registered_params, fisher_block in self.fisher_blocks.items(): - registered_params_set = set(utils.ensure_sequence(registered_params)) - for variable in params: - if (variable in registered_params_set and - params != registered_params_set): - raise ValueError( - "Can`t link parameters {}, variable {} was already registered in " - "group {} with layer {}".format(params, variable, - registered_params, fisher_block)) - - # Check if any of the variables in `params` is already in - # 'self.linked_parameters`. 
- for variable in params: - for other_linked_params in self.linked_parameters: - if variable in other_linked_params: - raise ValueError("Can`t link parameters {}, variable {} was already " - "linked in group {}.".format(params, variable, - other_linked_params)) - self._linked_parameters[params] = approximation - - def create_subgraph(self): - if not self.losses: - raise ValueError("Must have at least one registered loss.") - inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses)) - self._subgraph = utils.SubGraph(inputs_to_losses) - - def eval_losses(self): - """Return evaluated losses (colocated with inputs to losses).""" - evals = [] - for loss in self.losses: - with ops.colocate_with(self.loss_colocation_ops[loss]): - evals.append(loss.evaluate()) - return evals - - def eval_losses_on_samples(self): - """Return losses evaluated on samples (colocated with inputs to losses).""" - evals = [] - for loss in self.losses: - with ops.colocate_with(self.loss_colocation_ops[loss]): - evals.append(loss.evaluate_on_sample()) - return evals - - def total_loss(self): - return math_ops.add_n(self.eval_losses()) - - def total_sampled_loss(self): - return math_ops.add_n(self.eval_losses_on_samples()) - - def _get_linked_approx(self, params): - """If params were linked, return their specified approximation.""" - params_set = frozenset(utils.ensure_sequence(params)) - if params_set in self.linked_parameters: - return self.linked_parameters[params_set] - else: - return None - - def _get_block_type(self, params, approx, default, approx_to_type): - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = default - - if approx not in approx_to_type: - raise ValueError("Bad value {} for approx.".format(approx)) - - return approx_to_type[approx], approx - - def register_embedding(self, - params, - inputs, - outputs, - approx=None, - reuse=VARIABLE_SCOPE): - """Registers an embedding layer. - - Args: - params: Embedding matrix of shape [vocab_size, embedding_size]. - inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices - into embedding matrix. - outputs: Tensor of shape [batch_size, embedding_size]. Outputs - produced by layer. - approx: str or None. If not None must be "kron". The Fisher - approximation to use. If None the default value is used. (Default: None) - reuse: bool or str. If True, this adds `inputs` and `outputs` as an - additional mini-batch/tower of data to use when estimating the Fisher - block for this layer (which must have already been registered). If - "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. - (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - KeyError: If reuse == True but no FisherBlock found for `params`. - ValueError: If reuse == True and FisherBlock found but of the wrong type. - """ - block_type, approx = self._get_block_type( - params, approx, self.default_embedding_approximation, - _EMBEDDING_APPROX_TO_BLOCK_TYPES) - - if isinstance(params, (tuple, list)): - raise ValueError("Bias not supported.") - vocab_size = int(params.shape[0]) - block = self.register_block( - params, block_type(self, vocab_size), reuse=reuse) - block.register_additional_tower(inputs, outputs) - - self._add_uses(params, 1) - - def register_fully_connected(self, - params, - inputs, - outputs, - approx=None, - reuse=VARIABLE_SCOPE): - """Registers a fully connected layer. - - Args: - params: Tensor or 2-tuple of Tensors corresponding to weight and bias of - this layer. 
Weight matrix should have shape [input_size, output_size]. - Bias should have shape [output_size]. - inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. - outputs: Tensor of shape [batch_size, output_size]. Outputs - produced by layer. - approx: str or None. If not None must be one of "kron" or "diagonal". - The Fisher approximation to use. If None the default value is used. - (Default: None) - reuse: bool or str. If True, this adds `inputs` and `outputs` as an - additional mini-batch/tower of data to use when estimating the Fisher - block for this layer (which must have already been registered). If - "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. - (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - KeyError: If reuse == True but no FisherBlock found for `params`. - ValueError: If reuse == True and FisherBlock found but of the wrong type. - """ - - block_type, approx = self._get_block_type( - params, approx, self.default_fully_connected_approximation, - _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES) - - has_bias = isinstance(params, (tuple, list)) - block = self.register_block(params, block_type(self, has_bias=has_bias), - reuse=reuse) - block.register_additional_tower(inputs, outputs) - - self._add_uses(params, 1) - - def register_conv2d(self, - params, - strides, - padding, - inputs, - outputs, - data_format=None, - dilations=None, - approx=None, - reuse=VARIABLE_SCOPE): - """Registers a call to tf.nn.conv2d(). - - Args: - params: Tensor or 2-tuple of Tensors corresponding to weight and bias of - this layer. Weight matrix should have shape [kernel_height, - kernel_width, in_channels, out_channels]. Bias should have shape - [out_channels]. - strides: List of 4 ints. Strides for convolution kernel. - padding: string. see tf.nn.conv2d for valid values. - inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs - to layer. - outputs: Tensor of shape [batch_size, height, width, out_channels]. - Output produced by layer. - data_format: str or None. Format of data. - dilations: List of 4 ints. Dilations along each dimension. - approx: str or None. If not None must be one of "kron" or "diagonal". - The Fisher approximation to use. If None the default value is used. - (Default: None) - reuse: bool or str. If True, this adds `inputs` and `outputs` as an - additional mini-batch/tower of data to use when estimating the Fisher - block for this layer (which must have already been registered). If - "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. - (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - KeyError: If reuse == True but no FisherBlock found for `params`. - ValueError: If reuse == True and FisherBlock found but of the wrong type. - """ - - block_type, approx = self._get_block_type( - params, approx, self.default_conv2d_approximation, - _CONV2D_APPROX_TO_BLOCK_TYPES) - - # It feels bad to pass in configuration that has to do with the internal - # implementation. And then we can`t use the same constructor for both - # anymore and are thus forced to use this ugly if-statement. - # TODO(b/74793309): Clean this up? 
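As a concrete, hypothetical illustration of the dispatch below, a typical call that reaches the Kronecker branch looks like this (assuming a `layer_collection` like the one sketched earlier and an `activations` Tensor from the preceding layer):

```python
import tensorflow as tf

# Hypothetical setup: [batch, 28, 28, 32] activations from the previous layer.
activations = tf.placeholder(tf.float32, [None, 28, 28, 32])
conv_w = tf.get_variable("conv_w", [3, 3, 32, 64])
conv_b = tf.get_variable("conv_b", [64])
pre_activations = tf.nn.conv2d(
    activations, conv_w, strides=[1, 1, 1, 1], padding="SAME") + conv_b

# Default approx is "kron", so this takes the Kronecker branch below.
layer_collection.register_conv2d(
    params=(conv_w, conv_b),
    strides=[1, 1, 1, 1],
    padding="SAME",
    inputs=activations,
    outputs=pre_activations)
```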
- if approx == APPROX_KRONECKER_NAME: - block = self.register_block( - params, - block_type( - layer_collection=self, - params=params, - padding=padding, - strides=strides, - data_format=data_format, - dilation_rate=dilations, - extract_patches_fn="extract_image_patches"), - reuse=reuse) - elif approx == APPROX_DIAGONAL_NAME: - assert strides[0] == strides[-1] == 1 - block = self.register_block( - params, - block_type( - layer_collection=self, - params=params, - padding=padding, - strides=strides, - dilations=dilations, - data_format=data_format), - reuse=reuse) - else: - raise NotImplementedError(approx) - - block.register_additional_tower(inputs, outputs) - - self._add_uses(params, 1) - - def register_convolution(self, - params, - inputs, - outputs, - padding, - strides=None, - dilation_rate=None, - data_format=None, - approx=None, - reuse=VARIABLE_SCOPE): - """Register a call to tf.nn.convolution(). - - Args: - params: Tensor or 2-tuple of Tensors corresponding to weight and bias of - this layer. Weight matrix should have shape [..filter_spatial_size.., - in_channels, out_channels]. Bias should have shape [out_channels]. - inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels]. - Inputs to layer. - outputs: Tensor of shape [batch_size, ..output_spatial_size.., - out_channels]. Output produced by layer. - padding: string. see tf.nn.conv2d for valid values. - strides: List of ints of length len(..input_spatial_size..). Strides for - convolution kernel in spatial dimensions. - dilation_rate: List of ints of length len(..input_spatial_size..). - Dilations along spatial dimension. - data_format: str or None. Format of data. - approx: str or None. If not None must be one of "kron" or "diagonal". - The Fisher approximation to use. If None the default value is used. - (Default: None) - reuse: bool or str. If True, this adds `inputs` and `outputs` as an - additional mini-batch/tower of data to use when estimating the Fisher - block for this layer (which must have already been registered). If - "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. - (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - KeyError: If reuse == True but no FisherBlock found for `params`. - ValueError: If reuse == True and FisherBlock found but of the wrong type. - """ - # TODO(b/74793309): Have this use _get_block_type like the other - # registration functions? - assert approx is None or approx == APPROX_KRONECKER_NAME - - block = self.register_block( - params, - fb.ConvKFCBasicFB( - layer_collection=self, - params=params, - padding=padding, - strides=strides, - dilation_rate=dilation_rate, - data_format=data_format), - reuse=reuse) - block.register_additional_tower(inputs, outputs) - - self._add_uses(params, 1) - - def register_depthwise_conv2d(self, - params, - inputs, - outputs, - strides, - padding, - rate=None, - data_format=None, - approx=None, - reuse=VARIABLE_SCOPE): - """Register a call to tf.nn.depthwise_conv2d(). - - Args: - params: 4-D Tensor of shape [filter_height, filter_width, - in_channels, channel_multiplier]. Convolutional filter. - inputs: Tensor of shape [batch_size, input_height, input_width, - in_channels]. Inputs to layer. - outputs: Tensor of shape [batch_size, output_height, output_width, - in_channels * channel_multiplier]. Output produced by depthwise conv2d. - strides: List of ints of length 4. Strides along all dimensions. - padding: string. see tf.nn.conv2d for valid values. - rate: None or List of ints of length 2. 
Dilation rates in spatial - dimensions. - data_format: str or None. Format of data. - approx: str or None. If not None must "diagonal". The Fisher - approximation to use. If None the default value is used. (Default: None) - reuse: bool or str. If True, this adds `inputs` and `outputs` as an - additional mini-batch/tower of data to use when estimating the Fisher - block for this layer (which must have already been registered). If - "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. - (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - KeyError: If reuse == True but no FisherBlock found for `params`. - ValueError: If reuse == True and FisherBlock found but of the wrong type. - """ - # TODO(b/74793309): Have this use _get_block_type like the other - # registration functions? - assert approx is None or approx == APPROX_DIAGONAL_NAME - assert data_format in [None, "NHWC"] - - block = self.register_block( - params, - fb.DepthwiseConvDiagonalFB( - layer_collection=self, - params=params, - strides=strides, - padding=padding, - rate=rate, - data_format=data_format), - reuse=reuse) - block.register_additional_tower(inputs, outputs) - - self._add_uses(params, 1) - - def register_separable_conv2d(self, - depthwise_params, - pointwise_params, - inputs, - depthwise_outputs, - pointwise_outputs, - strides, - padding, - rate=None, - data_format=None, - approx=None, - reuse=VARIABLE_SCOPE): - """Register a call to tf.nn.separable_conv2d(). - - Note: This requires access to intermediate outputs between depthwise and - pointwise convolutions. - - Args: - depthwise_params: 4-D Tensor of shape [filter_height, filter_width, - in_channels, channel_multiplier]. Filter for depthwise conv2d. - pointwise_params: 4-D Tensor of shape [1, 1, in_channels * - channel_multiplier, out_channels]. Filter for pointwise conv2d. - inputs: Tensor of shape [batch_size, input_height, input_width, - in_channels]. Inputs to layer. - depthwise_outputs: Tensor of shape [batch_size, output_height, - output_width, in_channels * channel_multiplier]. Output produced by - depthwise conv2d. - pointwise_outputs: Tensor of shape [batch_size, output_height, - output_width, out_channels]. Output produced by pointwise conv2d. - strides: List of ints of length 4. Strides for depthwise conv2d kernel in - all dimensions. - padding: string. see tf.nn.conv2d for valid values. - rate: None or List of ints of length 2. Dilation rate of depthwise conv2d - kernel in spatial dimensions. - data_format: str or None. Format of data. - approx: str or None. If not None must be one of "kron" or "diagonal". - The Fisher approximation to use. If None the default value is used. - (Default: None) - reuse: bool or str. If True, this adds `inputs` and `outputs` as an - additional mini-batch/tower of data to use when estimating the Fisher - block for this layer (which must have already been registered). If - "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. - (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - KeyError: If reuse == True but no FisherBlock found for `params`. - ValueError: If reuse == True and FisherBlock found but of the wrong type. 
- """
- self.register_depthwise_conv2d(
- params=depthwise_params,
- inputs=inputs,
- outputs=depthwise_outputs,
- strides=strides,
- padding=padding,
- rate=rate,
- data_format=data_format,
- approx=APPROX_DIAGONAL_NAME,
- reuse=reuse)
-
- self.register_conv2d(
- params=pointwise_params,
- inputs=depthwise_outputs,
- outputs=pointwise_outputs,
- strides=[1, 1, 1, 1],
- padding="VALID",
- data_format=data_format,
- approx=approx,
- reuse=reuse)
-
- def register_generic(self,
- params,
- batch_size,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers a generic layer.
-
- Args:
- params: Tensor or tuple of Tensors corresponding to the parameters.
- batch_size: 0-D Tensor. Size of the minibatch (for this tower).
- approx: str or None. If not None, must be one of "full" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `batch_size` to the total
- mini-batch size used when estimating the Fisher block for this layer
- (which must have already been registered). If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_generic_approximation,
- _GENERIC_APPROX_TO_BLOCK_TYPES)
-
- block = self.register_block(params, block_type(self, params), reuse=reuse)
- block.register_additional_tower(batch_size)
-
- self._add_uses(params, float("inf"))
-
- def register_fully_connected_multi(self, params, inputs, outputs,
- num_uses=None, approx=None,
- reuse=VARIABLE_SCOPE):
- """Register fully connected layers with shared parameters.
-
- This can handle general fully-connected layers with shared parameters, but
- has specialized approximations to deal with the case where there is a
- meaningful linear order to the shared instances (such as in an RNN).
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [input_size, output_size].
- Bias should have shape [output_size].
- inputs: A list of Tensors, each of shape [batch_size, input_size]. Inputs
- to layer. The list indexes each use in the graph (which might
- correspond to a "time-step" in an RNN). OR, can be a single Tensor of
- shape [num_uses * batch_size, input_size], which is a reshaped version
- of a Tensor of shape [num_uses, batch_size, input_size].
- outputs: A list of Tensors, the same length as `inputs`, each of shape
- [batch_size, output_size]. Outputs produced by layer. The list indexes
- each use in the graph (which might correspond to a "time-step" in an
- RNN). Needs to correspond with the order used in `inputs`. OR, can be
- a single Tensor of shape [num_uses * batch_size, output_size], which is
- a reshaped version of a Tensor of shape [num_uses, batch_size,
- output_size].
- num_uses: int or None. The number of uses/time-steps in the graph where
- the layer appears. Only needed if both inputs and outputs are given in
- the single Tensor format. (Default: None)
- approx: str or None. If not None, must be one of "kron_indep",
- "kron_series_1" or "kron_series_2". The Fisher approximation to use.
- If None the default value is used. (Default: None)
- reuse: bool or str.
If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
- word `use` here has a completely different meaning from "use in the graph"
- as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_fully_connected_multi_approximation,
- _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES)
-
- # TODO(b/70283649): something along the lines of find_canonical_output
- # should be added back in here (and for the other block types, arguably).
-
- has_bias = isinstance(params, (tuple, list))
- block = self.register_block(params, block_type(self, has_bias=has_bias,
- num_uses=num_uses),
- reuse=reuse)
- block.register_additional_tower(inputs, outputs)
- if isinstance(inputs, (tuple, list)):
- assert len(inputs) == len(outputs)
- self._add_uses(params, len(inputs))
- else:
- self._add_uses(params, 1)
-
- def register_conv2d_multi(self,
- params,
- strides,
- padding,
- inputs,
- outputs,
- num_uses=None,
- data_format=None,
- dilations=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers convolutional layers with shared parameters.
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [kernel_height,
- kernel_width, in_channels, out_channels]. Bias should have shape
- [out_channels].
- strides: 1-D Tensor of length 4. Strides for convolution kernel.
- padding: string. See tf.nn.conv2d for valid values.
- inputs: A list of Tensors, each of shape [batch_size, height, width,
- in_channels]. Inputs to layer. The list indexes each use in the graph
- (which might correspond to a "time-step" in an RNN). OR, can be a single
- Tensor, of shape [num_uses * batch_size, height, width, in_channels],
- which is a reshaped version of a Tensor of shape [num_uses, batch_size,
- height, width, in_channels].
- outputs: A list of Tensors, each of shape [batch_size, height, width,
- out_channels]. Outputs produced by layer. The list indexes each use
- in the graph (which might correspond to a "time-step" in an RNN).
- Needs to correspond with the order used in `inputs`. OR, can be a
- single Tensor, of shape [num_uses * batch_size, height, width,
- out_channels], which is a reshaped version of a Tensor of shape
- [num_uses, batch_size, height, width, out_channels].
- num_uses: int or None. The number of uses/time-steps in the graph where
- the layer appears. Only needed if both inputs and outputs are given in
- the single Tensor format. (Default: None)
- data_format: str or None. Format of data.
- dilations: List of 4 ints. Dilations along each dimension.
- approx: str or None. If not None must be "kron_indep". The Fisher
- approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
- word `use` here has a completely different meaning from "use in the graph"
- as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_conv2d_multi_approximation,
- _CONV2D_MULTI_APPROX_TO_BLOCK_TYPES)
-
- block = self.register_block(
- params,
- block_type(
- layer_collection=self,
- params=params,
- padding=padding,
- strides=strides,
- data_format=data_format,
- dilation_rate=dilations,
- extract_patches_fn="extract_image_patches",
- num_uses=num_uses),
- reuse=reuse)
-
- block.register_additional_tower(inputs, outputs)
- if isinstance(inputs, (tuple, list)):
- assert len(inputs) == len(outputs)
- self._add_uses(params, len(inputs))
- else:
- self._add_uses(params, 1)
-
- # TODO(b/74108452): change the loss registration function names to refer
- # to "loss functions" instead of distributions, following the naming
- # convention of the loss function classes themselves.
-
- def register_embedding_multi(self,
- params,
- inputs,
- outputs,
- num_uses=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers embedding layers with shared parameters.
-
- Args:
- params: Embedding matrix of shape [vocab_size, embedding_size].
- inputs: A list of Tensors, each of shape [batch_size, input_size] and
- dtype int32. Indices into embedding matrix. The list indexes each use
- in the graph (which might correspond to a "time-step" in an RNN).
- OR, can be a single Tensor, of shape [num_uses * batch_size, input_size],
- which is a reshaped version of a Tensor of shape [num_uses, batch_size,
- input_size].
- outputs: A list of Tensors, each of shape [batch_size, embedding_size].
- Outputs produced by layer. The list indexes each use in the graph
- (which might correspond to a "time-step" in an RNN). Needs to
- correspond with the order used in `inputs`. OR, can be a
- single Tensor, of shape [num_uses * batch_size, embedding_size], which
- is a reshaped version of a Tensor of shape [num_uses, batch_size,
- embedding_size].
- num_uses: int or None. The number of uses/time-steps in the graph where
- the layer appears. Only needed if both inputs and outputs are given in
- the single Tensor format. (Default: None)
- approx: str or None. If not None must be "kron_indep". The Fisher
- approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
- word `use` here has a completely different meaning from "use in the graph"
- as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
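Example (editorial addition, not part of the original file): a sketch of the two equivalent calling conventions, assuming the pre-deletion contrib import path; shapes and names are illustrative.

    import tensorflow as tf
    from tensorflow.contrib.kfac.python.ops import layer_collection

    batch_size, input_size, num_uses = 8, 10, 5
    embedding = tf.get_variable("emb", shape=[1000, 64])
    ids_list = [tf.random_uniform([batch_size, input_size], maxval=1000,
                                  dtype=tf.int32) for _ in range(num_uses)]
    # A simple bag-of-words lookup so each output is [batch_size, 64].
    outs_list = [tf.reduce_sum(tf.nn.embedding_lookup(embedding, ids), axis=1)
                 for ids in ids_list]

    lc = layer_collection.LayerCollection()
    # List format: one Tensor per use/time-step.
    lc.register_embedding_multi(embedding, ids_list, outs_list)
    # The flattened format would instead pass single reshaped Tensors of
    # shape [num_uses * batch_size, ...] along with num_uses=num_uses
    # (register once, not both ways, since the parameters are the same).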
- """ - block_type, approx = self._get_block_type( - params, approx, self.default_embedding_multi_approximation, - _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES) - - if isinstance(params, (tuple, list)): - raise ValueError("Bias not supported.") - vocab_size = int(params.shape[0]) - - block = self.register_block( - params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse) - block.register_additional_tower(inputs, outputs) - - if isinstance(inputs, (tuple, list)): - self._add_uses(params, len(inputs)) - else: - self._add_uses(params, 1) - - def register_categorical_predictive_distribution(self, - logits, - seed=None, - targets=None, - name=None, - reuse=VARIABLE_SCOPE): - """Registers a categorical predictive distribution. - - Args: - logits: The logits of the distribution (i.e. its parameters). - seed: The seed for the RNG (for debugging) (Default: None) - targets: (OPTIONAL) The targets for the loss function. Only required if - one wants to call total_loss() instead of total_sampled_loss(). - total_loss() is required, for example, to estimate the - "empirical Fisher" (instead of the true Fisher). - (Default: None) - name: (OPTIONAL) str or None. Unique name for this loss function. If None, - a new name is generated. (Default: None) - reuse: bool or str. If True, this adds `logits` as an additional - mini-batch/tower of inputs to the loss-function/predictive distribution - (which must have already been registered). If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") - """ - loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets, - seed=seed) - self.register_loss_function(loss, logits, - "categorical_predictive_distribution", - name=name, reuse=reuse) - - def register_normal_predictive_distribution(self, - mean, - var=0.5, - seed=None, - targets=None, - name=None, - reuse=VARIABLE_SCOPE): - """Registers a normal predictive distribution. - - Args: - mean: The mean vector defining the distribution. - var: The variance (must be a scalar). Note that the default value of - 0.5 corresponds to a standard squared error loss (target - - prediction)**2. If your squared error loss is of the form - 0.5*(target - prediction)**2 you should use var=1.0. (Default: 0.5) - seed: The seed for the RNG (for debugging) (Default: None) - targets: (OPTIONAL) The targets for the loss function. Only required if - one wants to call total_loss() instead of total_sampled_loss(). - total_loss() is required, for example, to estimate the - "empirical Fisher" (instead of the true Fisher). - (Default: None) - name: (OPTIONAL) str or None. Unique name for this loss function. If None, - a new name is generated. (Default: None) - reuse: bool or str. If True, this adds `mean` and `var` as an additional - mini-batch/tower of inputs to the loss-function/predictive distribution - (which must have already been registered). If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") - """ - loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets, - seed=seed) - self.register_loss_function(loss, mean, - "normal_predictive_distribution", - name=name, reuse=reuse) - - def register_multi_bernoulli_predictive_distribution(self, - logits, - seed=None, - targets=None, - name=None, - reuse=VARIABLE_SCOPE): - """Registers a multi-Bernoulli predictive distribution. - - Args: - logits: The logits of the distribution (i.e. its parameters). 
- seed: The seed for the RNG (for debugging) (Default: None)
- targets: (OPTIONAL) The targets for the loss function. Only required if
- one wants to call total_loss() instead of total_sampled_loss().
- total_loss() is required, for example, to estimate the
- "empirical Fisher" (instead of the true Fisher).
- (Default: None)
- name: (OPTIONAL) str or None. Unique name for this loss function. If None,
- a new name is generated. (Default: None)
- reuse: bool or str. If True, this adds `logits` as an additional
- mini-batch/tower of inputs to the loss-function/predictive distribution
- (which must have already been registered). If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
- """
- loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets,
- seed=seed)
- self.register_loss_function(loss, logits,
- "multi_bernoulli_predictive_distribution",
- name=name, reuse=reuse)
-
- def make_or_get_factor(self, cls, args):
- """Insert `cls(args)` into `self.fisher_factors` if not already present.
-
- Wraps constructor in `tf.variable_scope()` to ensure variables constructed
- in `cls.__init__` are placed under this LayerCollection's scope.
-
- Args:
- cls: Class that implements FisherFactor.
- args: Tuple of arguments to pass into `cls`'s constructor. Must be
- hashable.
-
- Returns:
- Instance of `cls` found in self.fisher_factors.
- """
- try:
- hash(args)
- except TypeError:
- raise TypeError(
- ("Unable to use (cls, args) = ({}, {}) as a key in "
- "LayerCollection.fisher_factors. The pair cannot be hashed.").format(
- cls, args))
-
- key = cls, args
- if key not in self.fisher_factors:
- with variable_scope.variable_scope(self._var_scope):
- self.fisher_factors[key] = cls(*args)
- return self.fisher_factors[key]
-
- @contextmanager
- def as_default(self):
- """Sets this LayerCollection as the default."""
- set_default_layer_collection(self)
- yield
- set_default_layer_collection(None)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
deleted file mode 100644
index 9f46853807..0000000000
--- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Registry for layers and their parameters/variables.
-
-This represents the collection of all layers in the approximate Fisher
-information matrix to which a particular FisherBlock may belong. That is, we
-might have several layer collections for one TF graph (if we have multiple
-K-FAC optimizers being used, for example).
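Editorial example (not part of the original file): a minimal sketch of scoping registrations to one of several collections via `as_default()`:

    from tensorflow.contrib.kfac.python.ops import layer_collection

    lc = layer_collection.LayerCollection()
    with lc.as_default():
      # Code that calls layer_collection.get_default_layer_collection()
      # now sees `lc`; on exit the default is reset to None.
      assert layer_collection.get_default_layer_collection() is lc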
-""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.layer_collection import * -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import - -_allowed_symbols = [ - "get_default_layer_collection", - "set_default_layer_collection", - "LayerParametersDict", - "LayerCollection", - "APPROX_KRONECKER_NAME", - "APPROX_DIAGONAL_NAME", - "APPROX_FULL_NAME", - "VARIABLE_SCOPE", - "APPROX_KRONECKER_INDEP_NAME", - "APPROX_KRONECKER_SERIES_1_NAME", - "APPROX_KRONECKER_SERIES_2_NAME" -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/linear_operator.py b/tensorflow/contrib/kfac/python/ops/linear_operator.py deleted file mode 100644 index 61cb955ae8..0000000000 --- a/tensorflow/contrib/kfac/python/ops/linear_operator.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SmartMatrices definitions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.linalg import linalg -from tensorflow.python.ops.linalg import linalg_impl -from tensorflow.python.ops.linalg import linear_operator_util as lou - - -class LinearOperatorExtras(object): # pylint: disable=missing-docstring - - def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): - - with self._name_scope(name, values=[x]): - if isinstance(x, ops.IndexedSlices): - return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg) - - x = ops.convert_to_tensor(x, name="x") - self._check_input_dtype(x) - - self_dim = -2 if adjoint else -1 - arg_dim = -1 if adjoint_arg else -2 - self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) - - return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) - - def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"): - - with self._name_scope(name, values=[x]): - - if isinstance(x, ops.IndexedSlices): - return self._matmul_right_sparse( - x, adjoint=adjoint, adjoint_arg=adjoint_arg) - - x = ops.convert_to_tensor(x, name="x") - self._check_input_dtype(x) - - self_dim = -1 if adjoint else -2 - arg_dim = -2 if adjoint_arg else -1 - self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) - - return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg) - - -class LinearOperatorFullMatrix(LinearOperatorExtras, - linalg.LinearOperatorFullMatrix): - - # TODO(b/78117889) Remove this definition once core LinearOperator - # has _matmul_right. 
- def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
- return lou.matmul_with_broadcast(
- x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint)
-
- def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
- raise NotImplementedError
-
- def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
- assert not adjoint and not adjoint_arg
- return utils.matmul_sparse_dense(x, self._matrix)
-
-
-class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring
- linalg.LinearOperatorDiag):
-
- def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
- diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
- x = linalg_impl.adjoint(x) if adjoint_arg else x
- return diag_mat * x
-
- def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
- diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
- assert not adjoint_arg
- return utils.matmul_diag_sparse(diag_mat, x)
-
- def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
- raise NotImplementedError
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py
deleted file mode 100644
index c8cebc42cb..0000000000
--- a/tensorflow/contrib/kfac/python/ops/loss_functions.py
+++ /dev/null
@@ -1,754 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Loss functions to be used by LayerCollection."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-
-import six
-
-from tensorflow.contrib.distributions.python.ops import onehot_categorical
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.distributions import bernoulli
-from tensorflow.python.ops.distributions import categorical
-from tensorflow.python.ops.distributions import normal
-
-
-@six.add_metaclass(abc.ABCMeta)
-class LossFunction(object):
- """Abstract base class for loss functions.
-
- Note that unlike typical loss functions used in neural networks these are
- summed and not averaged across cases in the batch, since this is what the
- users of this class (FisherEstimator and MatrixVectorProductComputer) will
- be expecting. The implication of this is that you may want to
- normalize things like Fisher-vector products by the batch size when you
- use this class. It depends on the use case.
- """
-
- @abc.abstractproperty
- def targets(self):
- """The targets being predicted by the model.
-
- Returns:
- None or Tensor of appropriate shape for calling self._evaluate() on.
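Editorial sketch (not part of the original file): losses registered without targets can only be evaluated on sampled targets, per the evaluate() method below; using the concrete subclass defined later in this file:

    loss = NormalMeanNegativeLogProbLoss(mean)  # targets=None
    loss.evaluate_on_sample()  # OK: samples targets from the distribution
    loss.evaluate()            # raises: targets were never specified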
- """ - pass - - @abc.abstractproperty - def inputs(self): - """The inputs to the loss function (excluding the targets).""" - pass - - def evaluate(self): - """Evaluate the loss function on the targets.""" - if self.targets is not None: - # We treat the targets as "constant". It's only the inputs that get - # "back-propped" through. - return self._evaluate(array_ops.stop_gradient(self.targets)) - else: - raise Exception("Cannot evaluate losses with unspecified targets.") - - @abc.abstractmethod - def _evaluate(self, targets): - """Evaluates the negative log probability of the targets. - - Args: - targets: Tensor that distribution can calculate log_prob() of. - - Returns: - negative log probability of each target, summed across all targets. - """ - pass - - @abc.abstractmethod - def multiply_hessian(self, vector): - """Right-multiply a vector by the Hessian. - - Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) - of the loss function with respect to its inputs. - - Args: - vector: The vector to multiply. Must be the same shape(s) as the - 'inputs' property. - - Returns: - The vector right-multiplied by the Hessian. Will be of the same shape(s) - as the 'inputs' property. - """ - pass - - @abc.abstractmethod - def multiply_hessian_factor(self, vector): - """Right-multiply a vector by a factor B of the Hessian. - - Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) - of the loss function with respect to its inputs. Typically this will be - block-diagonal across different cases in the batch, since the loss function - is typically summed across cases. - - Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, - but will agree with the one used in the other methods of this class. - - Args: - vector: The vector to multiply. Must be of the shape given by the - 'hessian_factor_inner_shape' property. - - Returns: - The vector right-multiplied by B. Will be of the same shape(s) as the - 'inputs' property. - """ - pass - - @abc.abstractmethod - def multiply_hessian_factor_transpose(self, vector): - """Right-multiply a vector by the transpose of a factor B of the Hessian. - - Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) - of the loss function with respect to its inputs. Typically this will be - block-diagonal across different cases in the batch, since the loss function - is typically summed across cases. - - Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, - but will agree with the one used in the other methods of this class. - - Args: - vector: The vector to multiply. Must be the same shape(s) as the - 'inputs' property. - - Returns: - The vector right-multiplied by B^T. Will be of the shape given by the - 'hessian_factor_inner_shape' property. - """ - pass - - @abc.abstractmethod - def multiply_hessian_factor_replicated_one_hot(self, index): - """Right-multiply a replicated-one-hot vector by a factor B of the Hessian. - - Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) - of the loss function with respect to its inputs. Typically this will be - block-diagonal across different cases in the batch, since the loss function - is typically summed across cases. - - A 'replicated-one-hot' vector means a tensor which, for each slice along the - batch dimension (assumed to be dimension 0), is 1.0 in the entry - corresponding to the given index and 0 elsewhere. 
-
- Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
- but will agree with the one used in the other methods of this class.
-
- Args:
- index: A tuple representing the index of the entry in each slice that
- is 1.0. Note that len(index) must be equal to the number of elements
- of the 'hessian_factor_inner_shape' tensor minus one.
-
- Returns:
- The vector right-multiplied by B. Will be of the same shape(s) as the
- 'inputs' property.
- """
- pass
-
- @abc.abstractproperty
- def hessian_factor_inner_shape(self):
- """The shape of the tensor returned by multiply_hessian_factor."""
- pass
-
- @abc.abstractproperty
- def hessian_factor_inner_static_shape(self):
- """Static version of hessian_factor_inner_shape."""
- pass
-
-
-@six.add_metaclass(abc.ABCMeta)
-class NegativeLogProbLoss(LossFunction):
- """Abstract base class for loss functions that are negative log probs."""
-
- def __init__(self, seed=None):
- self._default_seed = seed
- super(NegativeLogProbLoss, self).__init__()
-
- @property
- def inputs(self):
- return self.params
-
- @abc.abstractproperty
- def params(self):
- """Parameters to the underlying distribution."""
- pass
-
- @abc.abstractmethod
- def multiply_fisher(self, vector):
- """Right-multiply a vector by the Fisher.
-
- Args:
- vector: The vector to multiply. Must be the same shape(s) as the
- 'inputs' property.
-
- Returns:
- The vector right-multiplied by the Fisher. Will be of the same shape(s)
- as the 'inputs' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_fisher_factor(self, vector):
- """Right-multiply a vector by a factor B of the Fisher.
-
- Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
- product of gradients) with respect to the parameters of the underlying
- probability distribution (whose log-prob defines the loss). Typically this
- will be block-diagonal across different cases in the batch, since the
- distribution is usually (but not always) conditionally iid across different
- cases.
-
- Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
- but will agree with the one used in the other methods of this class.
-
- Args:
- vector: The vector to multiply. Must be of the shape given by the
- 'fisher_factor_inner_shape' property.
-
- Returns:
- The vector right-multiplied by B. Will be of the same shape(s) as the
- 'inputs' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_fisher_factor_transpose(self, vector):
- """Right-multiply a vector by the transpose of a factor B of the Fisher.
-
- Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
- product of gradients) with respect to the parameters of the underlying
- probability distribution (whose log-prob defines the loss). Typically this
- will be block-diagonal across different cases in the batch, since the
- distribution is usually (but not always) conditionally iid across different
- cases.
-
- Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
- but will agree with the one used in the other methods of this class.
-
- Args:
- vector: The vector to multiply. Must be the same shape(s) as the
- 'inputs' property.
-
- Returns:
- The vector right-multiplied by B^T. Will be of the shape given by the
- 'fisher_factor_inner_shape' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_fisher_factor_replicated_one_hot(self, index):
- """Right-multiply a replicated-one-hot vector by a factor B of the Fisher.
-
- Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
- product of gradients) with respect to the parameters of the underlying
- probability distribution (whose log-prob defines the loss). Typically this
- will be block-diagonal across different cases in the batch, since the
- distribution is usually (but not always) conditionally iid across different
- cases.
-
- A 'replicated-one-hot' vector means a tensor which, for each slice along the
- batch dimension (assumed to be dimension 0), is 1.0 in the entry
- corresponding to the given index and 0 elsewhere.
-
- Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
- but will agree with the one used in the other methods of this class.
-
- Args:
- index: A tuple representing the index of the entry in each slice that
- is 1.0. Note that len(index) must be equal to the number of elements
- of the 'fisher_factor_inner_shape' tensor minus one.
-
- Returns:
- The vector right-multiplied by B. Will be of the same shape(s) as the
- 'inputs' property.
- """
- pass
-
- @abc.abstractproperty
- def fisher_factor_inner_shape(self):
- """The shape of the tensor returned by multiply_fisher_factor."""
- pass
-
- @abc.abstractproperty
- def fisher_factor_inner_static_shape(self):
- """Static version of fisher_factor_inner_shape."""
- pass
-
- @abc.abstractmethod
- def sample(self, seed):
- """Sample 'targets' from the underlying distribution."""
- pass
-
- def evaluate_on_sample(self, seed=None):
- """Evaluates the log probability on a random sample.
-
- Args:
- seed: int or None. Random seed for this draw from the distribution.
-
- Returns:
- Log probability of sampled targets, summed across examples.
- """
- if seed is None:
- seed = self._default_seed
- # We treat the targets as "constant". It's only the inputs that get
- # "back-propped" through.
- return self._evaluate(array_ops.stop_gradient(self.sample(seed)))
-
-
-# TODO(jamesmartens): should this just inherit from object to avoid "diamond"
-# inheritance, or is there a better way?
-class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss):
- """Base class for neg log prob losses whose inputs are 'natural' parameters.
-
- Note that the Hessian and Fisher for natural parameters of exponential-
- family models are the same, hence the purpose of this class.
- See here: https://arxiv.org/abs/1412.1193
-
- 'Natural parameters' are defined for exponential-family models.
See for
- example: https://en.wikipedia.org/wiki/Exponential_family
- """
-
- def multiply_hessian(self, vector):
- return self.multiply_fisher(vector)
-
- def multiply_hessian_factor(self, vector):
- return self.multiply_fisher_factor(vector)
-
- def multiply_hessian_factor_transpose(self, vector):
- return self.multiply_fisher_factor_transpose(vector)
-
- def multiply_hessian_factor_replicated_one_hot(self, index):
- return self.multiply_fisher_factor_replicated_one_hot(index)
-
- @property
- def hessian_factor_inner_shape(self):
- return self.fisher_factor_inner_shape
-
- @property
- def hessian_factor_inner_static_shape(self):
- return self.fisher_factor_inner_static_shape
-
-
-class DistributionNegativeLogProbLoss(NegativeLogProbLoss):
- """Base class for neg log prob losses that use the TF Distribution classes."""
-
- def __init__(self, seed=None):
- super(DistributionNegativeLogProbLoss, self).__init__(seed=seed)
-
- @abc.abstractproperty
- def dist(self):
- """The underlying tf.distributions.Distribution."""
- pass
-
- def _evaluate(self, targets):
- return -math_ops.reduce_sum(self.dist.log_prob(targets))
-
- def sample(self, seed):
- return self.dist.sample(seed=seed)
-
-
-class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
- NaturalParamsNegativeLogProbLoss):
- """Neg log prob loss for a normal distribution parameterized by a mean vector.
-
- Note that the covariance is treated as a constant 'var' times the identity.
- Also note that the Fisher for such a normal distribution with respect to the
- mean parameter is given by:
-
- F = (1/var) * I
-
- See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.
- """
-
- def __init__(self, mean, var=0.5, targets=None, seed=None):
- self._mean = mean
- self._var = var
- self._targets = targets
- super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed)
-
- @property
- def targets(self):
- return self._targets
-
- @property
- def dist(self):
- return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var))
-
- @property
- def params(self):
- return self._mean
-
- def multiply_fisher(self, vector):
- return (1. / self._var) * vector
-
- def multiply_fisher_factor(self, vector):
- return self._var**-0.5 * vector
-
- def multiply_fisher_factor_transpose(self, vector):
- return self.multiply_fisher_factor(vector) # it's symmetric in this case
-
- def multiply_fisher_factor_replicated_one_hot(self, index):
- assert len(index) == 1, "Length of index was {}".format(len(index))
- ones_slice = array_ops.expand_dims(
- array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype),
- axis=-1)
- output_slice = self._var**-0.5 * ones_slice
- return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]),
- index[0])
-
- @property
- def fisher_factor_inner_shape(self):
- return array_ops.shape(self._mean)
-
- @property
- def fisher_factor_inner_static_shape(self):
- return self._mean.shape
-
-
-class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
- """Negative log prob loss for a normal distribution with mean and variance.
-
- This class parameterizes a multivariate normal distribution with n independent
- dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not
- assume the variance is held constant. The Fisher Information for n = 1
- is given by,
-
- F = [[1 / variance, 0],
- [ 0, 0.5 / variance^2]]
-
- where the parameters of the distribution are concatenated into a single
- vector as [mean, variance].
For n > 1, the mean parameter vector is - concatenated with the variance parameter vector. - - See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation. - """ - - def __init__(self, mean, variance, targets=None, seed=None): - assert len(mean.shape) == 2, "Expect 2D mean tensor." - assert len(variance.shape) == 2, "Expect 2D variance tensor." - self._mean = mean - self._variance = variance - self._targets = targets - super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed) - - @property - def targets(self): - return self._targets - - @property - def dist(self): - return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance)) - - @property - def params(self): - return self._mean, self._variance - - def _concat(self, mean, variance): - return array_ops.concat([mean, variance], axis=-1) - - def _split(self, params): - return array_ops.split(params, 2, axis=-1) - - @property - def _fisher_mean(self): - return 1. / self._variance - - @property - def _fisher_mean_factor(self): - return 1. / math_ops.sqrt(self._variance) - - @property - def _fisher_var(self): - return 1. / (2 * math_ops.square(self._variance)) - - @property - def _fisher_var_factor(self): - return 1. / (math_ops.sqrt(2.) * self._variance) - - def multiply_fisher(self, vecs): - mean_vec, var_vec = vecs - return (self._fisher_mean * mean_vec, self._fisher_var * var_vec) - - def multiply_fisher_factor(self, vecs): - mean_vec, var_vec = self._split(vecs) - return (self._fisher_mean_factor * mean_vec, - self._fisher_var_factor * var_vec) - - def multiply_fisher_factor_transpose(self, vecs): - mean_vec, var_vec = vecs - return self._concat(self._fisher_mean_factor * mean_vec, - self._fisher_var_factor * var_vec) - - def multiply_fisher_factor_replicated_one_hot(self, index): - assert len(index) == 1, "Length of index was {}".format(len(index)) - index = index[0] - - if index < int(self._mean.shape[-1]): - # Index corresponds to mean parameter. - mean_slice = self._fisher_mean_factor[:, index] - mean_slice = array_ops.expand_dims(mean_slice, axis=-1) - mean_output = insert_slice_in_zeros(mean_slice, 1, int( - self._mean.shape[1]), index) - var_output = array_ops.zeros_like(mean_output) - else: - index -= int(self._mean.shape[-1]) - # Index corresponds to variance parameter. 
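# Editorial note (not part of the original file): e.g. with a mean of
# shape [batch, 3], a replicated-one-hot index of 4 lies past the 3 mean
# columns, so after the subtraction above it addresses variance column 1.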
- var_slice = self._fisher_var_factor[:, index]
- var_slice = array_ops.expand_dims(var_slice, axis=-1)
- var_output = insert_slice_in_zeros(var_slice, 1,
- int(self._variance.shape[1]), index)
- mean_output = array_ops.zeros_like(var_output)
-
- return mean_output, var_output
-
- @property
- def fisher_factor_inner_shape(self):
- return array_ops.concat(
- [
- array_ops.shape(self._mean)[:-1],
- 2 * array_ops.shape(self._mean)[-1:]
- ],
- axis=0)
-
- @property
- def fisher_factor_inner_static_shape(self):
- shape = self._mean.shape.as_list()
- return tensor_shape.TensorShape(shape[:-1] + [2 * shape[-1]])
-
- def multiply_hessian(self, vector):
- raise NotImplementedError()
-
- def multiply_hessian_factor(self, vector):
- raise NotImplementedError()
-
- def multiply_hessian_factor_transpose(self, vector):
- raise NotImplementedError()
-
- def multiply_hessian_factor_replicated_one_hot(self, index):
- raise NotImplementedError()
-
- @property
- def hessian_factor_inner_shape(self):
- raise NotImplementedError()
-
- @property
- def hessian_factor_inner_static_shape(self):
- raise NotImplementedError()
-
-
-class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
- NaturalParamsNegativeLogProbLoss):
- """Neg log prob loss for a categorical distribution parameterized by logits.
-
- Note that the Fisher (for a single case) of a categorical distribution, with
- respect to the natural parameters (i.e. the logits), is given by:
-
- F = diag(p) - p*p^T
-
- where p = softmax(logits). F can be factorized as F = B * B^T where
-
- B = diag(q) - p*q^T
-
- where q is the entry-wise square root of p. This is easy to verify using the
- fact that q^T*q = 1.
- """
-
- def __init__(self, logits, targets=None, seed=None):
- """Instantiates a CategoricalLogitsNegativeLogProbLoss.
-
- Args:
- logits: Tensor of shape [batch_size, output_size]. Parameters for
- underlying distribution.
- targets: None or Tensor of shape [batch_size]. Each element contains an
- index in [0, output_size).
- seed: int or None. Default random seed when sampling.
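Editorial sketch (not part of the original file): a quick numpy check of the factorization F = B * B^T claimed in the class docstring above:

    import numpy as np

    logits = np.array([0.5, -1.0, 2.0])
    p = np.exp(logits) / np.exp(logits).sum()   # softmax probabilities
    q = np.sqrt(p)
    F = np.diag(p) - np.outer(p, p)             # Fisher w.r.t. the logits
    B = np.diag(q) - np.outer(p, q)
    np.testing.assert_allclose(B.dot(B.T), F, atol=1e-12)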
- """ - self._logits = logits - self._targets = targets - super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed) - - @property - def targets(self): - return self._targets - - @property - def dist(self): - return categorical.Categorical(logits=self._logits) - - @property - def _probs(self): - return self.dist.probs - - @property - def _sqrt_probs(self): - return math_ops.sqrt(self._probs) - - @property - def params(self): - return self._logits - - def multiply_fisher(self, vector): - probs = self._probs - return vector * probs - probs * math_ops.reduce_sum( - vector * probs, axis=-1, keepdims=True) - - def multiply_fisher_factor(self, vector): - probs = self._probs - sqrt_probs = self._sqrt_probs - return sqrt_probs * vector - probs * math_ops.reduce_sum( - sqrt_probs * vector, axis=-1, keepdims=True) - - def multiply_fisher_factor_transpose(self, vector): - probs = self._probs - sqrt_probs = self._sqrt_probs - return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum( - probs * vector, axis=-1, keepdims=True) - - def multiply_fisher_factor_replicated_one_hot(self, index): - assert len(index) == 1, "Length of index was {}".format(len(index)) - probs = self._probs - sqrt_probs = self._sqrt_probs - sqrt_probs_slice = array_ops.expand_dims(sqrt_probs[:, index[0]], -1) - padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1, - int(sqrt_probs.shape[1]), index[0]) - return padded_slice - probs * sqrt_probs_slice - - @property - def fisher_factor_inner_shape(self): - return array_ops.shape(self._logits) - - @property - def fisher_factor_inner_static_shape(self): - return self._logits.shape - - -class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss, - NaturalParamsNegativeLogProbLoss): - """Neg log prob loss for multiple Bernoulli distributions param'd by logits. - - Represents N independent Bernoulli distributions where N = len(logits). Its - Fisher Information matrix is given by, - - F = diag(p * (1-p)) - p = sigmoid(logits) - - As F is diagonal with positive entries, its factor B is, - - B = diag(sqrt(p * (1-p))) - """ - - def __init__(self, logits, targets=None, seed=None): - self._logits = logits - self._targets = targets - super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed) - - @property - def targets(self): - return self._targets - - @property - def dist(self): - return bernoulli.Bernoulli(logits=self._logits) - - @property - def _probs(self): - return self.dist.probs - - @property - def params(self): - return self._logits - - def multiply_fisher(self, vector): - return self._probs * (1 - self._probs) * vector - - def multiply_fisher_factor(self, vector): - return math_ops.sqrt(self._probs * (1 - self._probs)) * vector - - def multiply_fisher_factor_transpose(self, vector): - return self.multiply_fisher_factor(vector) # it's symmetric in this case - - def multiply_fisher_factor_replicated_one_hot(self, index): - assert len(index) == 1, "Length of index was {}".format(len(index)) - probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1) - output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice)) - return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]), - index[0]) - - @property - def fisher_factor_inner_shape(self): - return array_ops.shape(self._logits) - - @property - def fisher_factor_inner_static_shape(self): - return self._logits.shape - - -def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position): - """Inserts slice into a larger tensor of zeros. 
- - Forms a new tensor which is the same shape as slice_to_insert, except that - the dimension given by 'dim' is expanded to the size given by 'dim_size'. - 'position' determines the position (index) at which to insert the slice within - that dimension. - - Assumes slice_to_insert.shape[dim] = 1. - - Args: - slice_to_insert: The slice to insert. - dim: The dimension which to expand with zeros. - dim_size: The new size of the 'dim' dimension. - position: The position of 'slice_to_insert' in the new tensor. - - Returns: - The new tensor. - - Raises: - ValueError: If the slice's shape at the given dim is not 1. - """ - slice_shape = slice_to_insert.shape - if slice_shape[dim] != 1: - raise ValueError("Expected slice_to_insert.shape to have {} dim of 1, but " - "was {}".format(dim, slice_to_insert.shape[dim])) - - before = [0] * int(len(slice_shape)) - after = before[:] - before[dim] = position - after[dim] = dim_size - position - 1 - - return array_ops.pad(slice_to_insert, list(zip(before, after))) - - -class OnehotCategoricalLogitsNegativeLogProbLoss( - CategoricalLogitsNegativeLogProbLoss): - """Neg log prob loss for a categorical distribution with onehot targets. - - Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying - distribution is OneHotCategorical as opposed to Categorical. - """ - - @property - def dist(self): - return onehot_categorical.OneHotCategorical(logits=self._logits) diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py deleted file mode 100644 index 4279cb2792..0000000000 --- a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Loss functions to be used by LayerCollection.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.loss_functions import * -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import - -_allowed_symbols = [ - "LossFunction", - "NegativeLogProbLoss", - "NaturalParamsNegativeLogProbLoss", - "DistributionNegativeLogProbLoss", - "NormalMeanNegativeLogProbLoss", - "NormalMeanVarianceNegativeLogProbLoss", - "CategoricalLogitsNegativeLogProbLoss", - "OnehotCategoricalLogitsNegativeLogProbLoss", - "MultiBernoulliNegativeLogProbLoss", - "insert_slice_in_zeros", -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/op_queue.py b/tensorflow/contrib/kfac/python/ops/op_queue.py deleted file mode 100644 index b6d9d37a31..0000000000 --- a/tensorflow/contrib/kfac/python/ops/op_queue.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Helper for choosing which op to run next in a distributed setting.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import ops as tf_ops - - -class OpQueue(object): - """Class for choosing which Op to run next. - - Constructs an infinitely repeating sequence of Ops in shuffled order. - - In K-FAC, this can be used to distribute inverse update operations among - workers. - """ - - def __init__(self, ops, seed=None): - """Initializes an OpQueue. - - Args: - ops: list of TensorFlow Ops. Ops to be selected from. All workers must - initialize with the same set of ops. - seed: int or None. Random seed used when shuffling order of ops. - """ - self._ops_by_name = {op.name: op for op in ops} - - # Construct a (shuffled) Dataset with Op names. - op_names = tf_ops.convert_to_tensor(list(sorted(op.name for op in ops))) - op_names_dataset = (dataset_ops.Dataset.from_tensor_slices(op_names) - .shuffle(len(ops), seed=seed).repeat()) - self._next_op_name = op_names_dataset.make_one_shot_iterator().get_next() - - @property - def ops(self): - """Ops this OpQueue can return in next_op().""" - return self._ops_by_name.values() - - def next_op(self, sess): - """Chooses which op to run next. - - Note: This call will make a call to sess.run(). - - Args: - sess: tf.Session. - - Returns: - Next Op chosen from 'ops'. - """ - # In Python 3, type(next_op_name) == bytes. Calling bytes.decode('ascii') - # returns a str. 
- next_op_name = sess.run(self._next_op_name).decode('ascii') - return self._ops_by_name[next_op_name] diff --git a/tensorflow/contrib/kfac/python/ops/op_queue_lib.py b/tensorflow/contrib/kfac/python/ops/op_queue_lib.py deleted file mode 100644 index 09c9a4ab33..0000000000 --- a/tensorflow/contrib/kfac/python/ops/op_queue_lib.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Helper for choosing which op to run next in a distributed setting.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.op_queue import * -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import - -_allowed_symbols = [ - 'OpQueue', -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py deleted file mode 100644 index 38605259b5..0000000000 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ /dev/null @@ -1,727 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ==============================================================================
-"""The KFAC optimizer."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import warnings
-
-# pylint: disable=line-too-long
-from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
-from tensorflow.contrib.kfac.python.ops import estimator as est
-# pylint: enable=line-too-long
-
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.training import gradient_descent
-
-
-class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
- """The KFAC Optimizer (https://arxiv.org/abs/1503.05671)."""
-
- def __init__(self,
- learning_rate,
- cov_ema_decay,
- damping,
- layer_collection,
- var_list=None,
- momentum=0.9,
- momentum_type="regular",
- norm_constraint=None,
- name="KFAC",
- estimation_mode="gradients",
- colocate_gradients_with_ops=True,
- batch_size=None,
- placement_strategy=None,
- **kwargs):
- """Initializes the KFAC optimizer with the given settings.
-
- Args:
- learning_rate: The base learning rate for the optimizer. Should probably
- be set to 1.0 when using momentum_type = 'qmodel', but can still be
- lowered if desired (effectively lowering the trust in the
- quadratic model.)
- cov_ema_decay: The decay factor used when calculating the covariance
- estimate moving averages.
- damping: The damping factor used to stabilize training due to errors in
- the local approximation with the Fisher information matrix, and to
- regularize the update direction by making it closer to the gradient.
- If damping is adapted during training then this value is used for
- initializing the damping variable.
- (Higher damping means the update looks more like a standard gradient
- update - see Tikhonov regularization.)
- layer_collection: The layer collection object, which holds the Fisher
- blocks, Kronecker factors, and losses associated with the
- graph. The layer_collection cannot be modified after KfacOptimizer's
- initialization.
- var_list: Optional list or tuple of variables to train. Defaults to the
- list of variables collected in the graph under the key
- `GraphKeys.TRAINABLE_VARIABLES`.
- momentum: The momentum decay constant to use. Only applies when
- momentum_type is 'regular' or 'adam'. (Default: 0.9)
- momentum_type: The type of momentum to use in this optimizer, one of
- 'regular', 'adam', or 'qmodel'. (Default: 'regular')
- norm_constraint: float or Tensor. If specified, the update is scaled down
- so that its approximate squared Fisher norm v^T F v is at most the
- specified value. May only be used with momentum type 'regular'.
- (Default: None)
- name: The name for this optimizer. (Default: 'KFAC')
- estimation_mode: The type of estimator to use for the Fishers. Can be
- 'gradients', 'empirical', 'curvature_propagation', or 'exact'.
- (Default: 'gradients'). See the doc-string for FisherEstimator for
- a more detailed description of these options.
- colocate_gradients_with_ops: Whether we should request gradients we
- compute in the estimator be colocated with their respective ops.
(Default: True)
- batch_size: The size of the mini-batch. Only needed when momentum_type
- == 'qmodel' or when automatic adjustment is used. (Default: None)
- placement_strategy: string, Device placement strategy used when creating
- covariance variables, covariance ops, and inverse ops.
- (Default: `None`)
- **kwargs: Arguments to be passed to the specific placement
- strategy mixin. Check `placement.RoundRobinPlacementMixin` for example.
-
- Raises:
- ValueError: If the momentum type is unsupported.
- ValueError: If clipping is used with momentum type other than 'regular'.
- ValueError: If no losses have been registered with layer_collection.
- ValueError: If momentum is non-zero and momentum_type is not 'regular'
- or 'adam'.
- """
- warnings.warn(
- "third_party.tensorflow.contrib.kfac is deprecated. "
- "This will be removed on 15-07-2018. Check README for further details.",
- DeprecationWarning)
- # Parameters to be passed to the Fisher estimator:
- self._variables = var_list or tf_variables.trainable_variables
- self._cov_ema_decay = cov_ema_decay
- self._layers = layer_collection
- self._estimation_mode = estimation_mode
- self._colocate_gradients_with_ops = colocate_gradients_with_ops
-
- # The below parameters are required only if damping needs to be adapted.
- # These parameters can be set by calling
- # set_damping_adaptation_params() explicitly.
- self._damping_adaptation_decay = 0.95
- self._damping_adaptation_interval = 5
- # See Section 6.5 of the K-FAC paper: omega(1) = pow(damping_decay, interval)
- self._omega = (
- self._damping_adaptation_decay**self._damping_adaptation_interval)
- self._adapt_damping = False
- self._min_damping = 1e-5
- self._prev_train_batch = None
- self._is_chief = False
- self._loss_fn = None
- self._damping_constant = damping
- self._damping = None
- self._rho = None
- self._prev_loss = None
- self._q_model_change = None
- self._update_damping_op = None
-
- momentum_type = momentum_type.lower()
- legal_momentum_types = ["regular", "adam", "qmodel"]
-
- if momentum_type not in legal_momentum_types:
- raise ValueError("Unsupported momentum type {}. Must be one of {}."
- .format(momentum_type, legal_momentum_types))
- if momentum_type != "regular" and norm_constraint is not None:
- raise ValueError("Update clipping is only supported with momentum "
- "type 'regular'.")
- if momentum_type not in ["regular", "adam"] and momentum != 0:
- raise ValueError("Momentum must be unspecified if using a momentum_type "
- "other than 'regular' or 'adam'.")
-
- # Extra parameters of the optimizer
- self._momentum = momentum
- self._momentum_type = momentum_type
- self._norm_constraint = norm_constraint
- self._batch_size = batch_size
- self._placement_strategy = placement_strategy
-
- with variable_scope.variable_scope(name):
- self._fisher_est = est.make_fisher_estimator(
- placement_strategy=placement_strategy,
- variables=self._variables,
- cov_ema_decay=self._cov_ema_decay,
- damping=self.damping,
- layer_collection=self._layers,
- exps=(-1,),
- estimation_mode=self._estimation_mode,
- colocate_gradients_with_ops=self._colocate_gradients_with_ops,
- **kwargs)
-
- super(KfacOptimizer, self).__init__(learning_rate, name=name)
-
- def set_damping_adaptation_params(self,
- is_chief,
- prev_train_batch,
- loss_fn,
- min_damping=1e-5,
- damping_adaptation_decay=0.99,
- damping_adaptation_interval=5):
- """Sets parameters required to adapt damping during training.
-
- When called, enables damping adaptation according to the Levenberg-Marquardt
- style rule described in Section 6.5 of "Optimizing Neural Networks with
- Kronecker-factored Approximate Curvature".
-
- Note that this function creates TensorFlow variables which store a few
- scalars and are accessed by the ops which update the damping (as part
- of the training op returned by the minimize() method).
-
- Args:
- is_chief: `Boolean`, `True` if the worker is chief.
- prev_train_batch: Training data used to minimize loss in the previous
- step. This will be used to evaluate loss by calling
- `loss_fn(prev_train_batch)`.
- loss_fn: `function` that takes as input training data tensor and returns
- a scalar loss.
- min_damping: `float` (Optional), Minimum value the damping parameter
- can take. Default value 1e-5.
- damping_adaptation_decay: `float` (Optional), The `damping` parameter is
- multiplied by the `damping_adaptation_decay` every
- `damping_adaptation_interval` number of iterations. Default value 0.99.
- damping_adaptation_interval: `int` (Optional), Number of steps in between
- updating the `damping` parameter. Default value 5.
-
- Raises:
- ValueError: If `set_damping_adaptation_params` has already been called
- and `adapt_damping` is `True`.
- """
- if self._adapt_damping:
- raise ValueError("Damping adaptation parameters already set.")
-
- with variable_scope.variable_scope(self.get_name()):
- self._adapt_damping = True
- self._is_chief = is_chief
- self._prev_train_batch = prev_train_batch
- self._loss_fn = loss_fn
- self._damping_adaptation_decay = damping_adaptation_decay
- self._damping_adaptation_interval = damping_adaptation_interval
- self._omega = (
- self._damping_adaptation_decay**self._damping_adaptation_interval)
- self._min_damping = min_damping
-
- self._rho = variable_scope.get_variable(
- "rho", shape=(), dtype=dtypes.float32, trainable=False) # LM ratio.
- self._prev_loss = variable_scope.get_variable(
- "prev_loss", shape=(), dtype=dtypes.float32, trainable=False)
- self._q_model_change = variable_scope.get_variable(
- "q_model_change", shape=(), dtype=dtypes.float32, trainable=False)
- self._damping = variable_scope.get_variable(
- "damping", initializer=self._damping_constant, trainable=False)
-
- @property
- def variables(self):
- return self._fisher_est.variables
-
- @property
- def damping(self):
- if self._damping:
- return self._damping
- else:
- return self._damping_constant
-
- @property
- def damping_adaptation_interval(self):
- return self._damping_adaptation_interval
-
- def make_vars_and_create_op_thunks(self):
- """Make vars and create op thunks.
-
- Returns:
- cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- """
- scope = self.get_name() + "/" + self._fisher_est.name
- return self._fisher_est.make_vars_and_create_op_thunks(scope=scope)
-
- def create_ops_and_vars_thunks(self):
- """Create thunks that make the ops and vars on demand.
-
- This function returns 4 lists of thunks: cov_variable_thunks,
- cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
-
- The length of each list is the number of factors and the i-th element of
- each list corresponds to the i-th factor (given by the "factors" property).
-
- Note that the execution of these thunks must happen in a certain
- partial order.
-    The i-th element of cov_variable_thunks must execute
-    before the i-th element of cov_update_thunks (and also the i-th element
-    of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
-    must execute before the i-th element of inv_update_thunks.
-
-    TL;DR (oversimplified): Execute the thunks according to the order that
-    they are returned.
-
-    Returns:
-      cov_variable_thunks: A list of thunks that make the cov variables.
-      cov_update_thunks: A list of thunks that make the cov update ops.
-      inv_variable_thunks: A list of thunks that make the inv variables.
-      inv_update_thunks: A list of thunks that make the inv update ops.
-    """
-    scope = self.get_name() + "/" + self._fisher_est.name
-    return self._fisher_est.create_ops_and_vars_thunks(scope=scope)
-
-  def minimize(self, *args, **kwargs):
-    # Should this variable scope encompass everything below? Or will the
-    # superclass make another copy of the same name scope?
-    with variable_scope.variable_scope(self.get_name()):
-      kwargs["var_list"] = kwargs.get("var_list") or self.variables
-      if set(kwargs["var_list"]) != set(self.variables):
-        raise ValueError("var_list doesn't match the set of Fisher-estimating "
-                         "variables.")
-      if self._adapt_damping and self._is_chief:
-        global_step = kwargs.get("global_step", None)
-        if not global_step:
-          raise KeyError("global_step needs to be passed to optimizer.minimize "
-                         "if damping parameter is adapted.")
-        update_damping_op = self._update_damping(self._prev_train_batch,
-                                                 global_step)
-        with ops.control_dependencies([update_damping_op]):
-          loss = args[0]
-          loss_assign_op = state_ops.assign(self._prev_loss, loss)
-          train_op = super(KfacOptimizer, self).minimize(*args, **kwargs)
-          return control_flow_ops.group(loss_assign_op, train_op)
-      else:
-        return super(KfacOptimizer, self).minimize(*args, **kwargs)
-
-  def compute_gradients(self, *args, **kwargs):
-    # args[1] could be our var_list.
-    if len(args) > 1:
-      var_list = args[1]
-    else:
-      kwargs["var_list"] = kwargs.get("var_list") or self.variables
-      var_list = kwargs["var_list"]
-
-    if set(var_list) != set(self.variables):
-      raise ValueError("var_list doesn't match the set of Fisher-estimating "
-                       "variables.")
-    return super(KfacOptimizer, self).compute_gradients(*args, **kwargs)
-
-  def apply_gradients(self, grads_and_vars, *args, **kwargs):
-    """Applies gradients to variables.
-
-    Args:
-      grads_and_vars: List of (gradient, variable) pairs.
-      *args: Additional arguments for super.apply_gradients.
-      **kwargs: Additional keyword arguments for super.apply_gradients.
-
-    Returns:
-      An `Operation` that applies the specified gradients.
-    """
-    # In Python 3, grads_and_vars can be a zip() object which can only be
-    # iterated over once. By converting it to a list, we ensure that it can be
-    # iterated over more than once.
-    grads_and_vars = list(grads_and_vars)
-
-    # Compute the update step.
-    steps_and_vars = self._compute_update_steps(grads_and_vars)
-
-    # Update trainable variables with this step.
-    return super(KfacOptimizer, self).apply_gradients(steps_and_vars, *args,
-                                                      **kwargs)
-
-  def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars):
-    """Computes the squared (approximate) Fisher norm of the updates.
-
-    This is defined as v^T F v, where F is the approximate Fisher matrix
-    as computed by the estimator, and v = F^{-1} g, where g is the gradient.
-    This is computed efficiently as v^T g.
-
-    Args:
-      grads_and_vars: List of (gradient, variable) pairs.
-      precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
-        Must be the result of calling `self._fisher_est.multiply_inverse`
-        on `grads_and_vars`.
-
-    Returns:
-      Scalar representing the squared norm.
-
-    Raises:
-      ValueError: if the two list arguments do not contain the same variables,
-        in the same order.
-    """
-    for (_, gvar), (_, pgvar) in zip(grads_and_vars, precon_grads_and_vars):
-      if gvar is not pgvar:
-        raise ValueError("The variables referenced by the two arguments "
-                         "must match.")
-    terms = [
-        math_ops.reduce_sum(grad * pgrad)
-        for (grad, _), (pgrad, _) in zip(grads_and_vars, precon_grads_and_vars)
-    ]
-    return math_ops.reduce_sum(terms)
-
-  def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars):
-    """Computes the scale factor for the update to satisfy the norm constraint.
-
-    Defined as min(1, sqrt(c / r^T F r)), where c is the norm constraint,
-    F is the approximate Fisher matrix, and r is the update vector, i.e.
-    -alpha * v, where alpha is the learning rate, and v is the preconditioned
-    gradient.
-
-    This is based on Section 5 of Ba et al., Distributed Second-Order
-    Optimization using Kronecker-Factored Approximations. Note that they
-    absorb the learning rate alpha (which they denote eta_max) into the formula
-    for the coefficient, while in our implementation, the rescaling is done
-    before multiplying by alpha. Hence, our formula differs from theirs by a
-    factor of alpha.
-
-    Args:
-      grads_and_vars: List of (gradient, variable) pairs.
-      precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
-        Must be the result of calling `self._fisher_est.multiply_inverse`
-        on `grads_and_vars`.
-
-    Returns:
-      Scalar representing the coefficient which should be applied to the
-      preconditioned gradients to satisfy the norm constraint.
-    """
-    sq_norm_grad = self._squared_fisher_norm(grads_and_vars,
-                                             precon_grads_and_vars)
-    sq_norm_up = sq_norm_grad * self._learning_rate**2
-    return math_ops.minimum(1.,
-                            math_ops.sqrt(self._norm_constraint / sq_norm_up))
-
-  def _clip_updates(self, grads_and_vars, precon_grads_and_vars):
-    """Rescales the preconditioned gradients to satisfy the norm constraint.
-
-    Rescales the preconditioned gradients such that the resulting update r
-    (after multiplying by the learning rate) will satisfy the norm constraint.
-    This constraint is that r^T F r <= C, where F is the approximate Fisher
-    matrix, and C is the norm_constraint attribute. See Section 5 of
-    Ba et al., Distributed Second-Order Optimization using Kronecker-Factored
-    Approximations.
-
-    Args:
-      grads_and_vars: List of (gradient, variable) pairs.
-      precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
-        Must be the result of calling `self._fisher_est.multiply_inverse`
-        on `grads_and_vars`.
-
-    Returns:
-      List of (rescaled preconditioned gradient, variable) pairs.
-    """
-    coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars)
-    return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars]
-
-  def _compute_prev_updates(self, variables):
-    """Computes previous updates as negative velocities scaled by the learning
-    rate.
-
-    Args:
-      variables: List of variables in the graph that the update will be
-        applied to.
-
-    Returns:
-      List of previous updates applied to the `variables`.
-    """
-    return list(
-        -1 * self._learning_rate * self._zeros_slot(var, "velocity", self._name)
-        for var in variables)
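
The KL-clipping helpers above (`_squared_fisher_norm`, `_update_clip_coeff`, `_clip_updates`) boil down to a few lines of arithmetic. A NumPy sketch with made-up numbers, assuming v = F^{-1} g has already been computed:

```python
# NumPy sketch of the KL-clipping coefficient from _update_clip_coeff above.
# All values are illustrative; v^T F v is computed as v^T g, exactly as in
# _squared_fisher_norm.
import numpy as np

learning_rate = 0.1
norm_constraint = 1e-4  # the constant C in r^T F r <= C

grads = [np.array([0.3, -0.2]), np.array([0.5])]          # g, per variable
precon_grads = [np.array([0.6, -0.1]), np.array([0.25])]  # v = F^{-1} g

sq_fisher_norm = sum(np.sum(g * v) for g, v in zip(grads, precon_grads))
sq_norm_up = sq_fisher_norm * learning_rate**2  # r^T F r with r = -lr * v
coeff = min(1.0, np.sqrt(norm_constraint / sq_norm_up))
print(coeff)  # ~0.175 here, so the preconditioned gradients would be rescaled
```
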
-  def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads,
-                                  variables):
-    """Compute optimal update hyperparameters from the quadratic model.
-
-    More specifically, if L is the loss, we minimize with respect to alpha and
-    mu a quadratic approximation of L(theta + d), denoted qmodel(d), where
-    d = alpha*precon_grad + mu*prev_update and
-
-      qmodel(d) = (1/2) * d^T * C * d + grad^T * d + L(theta) .
-
-    Unlike in the KL clipping approach, we use the non-approximated quadratic
-    model, where the curvature matrix C is the true Fisher on the current
-    mini-batch (computed without any approximations beyond mini-batch sampling),
-    with the usual Tikhonov damping/regularization applied,
-
-      C = F + damping * I
-
-    See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of
-    the formula. See Appendix C for a discussion of the trick of using
-    a factorized Fisher matrix to more efficiently compute the required
-    vector-matrix-vector products.
-
-    Note that the elements of all 4 lists passed to this function must
-    be in correspondence with each other.
-
-    Args:
-      precon_grads: List of preconditioned gradients.
-      prev_updates: List of updates computed at the previous iteration.
-      grads: List of gradients.
-      variables: List of variables in the graph that the update will be
-        applied to. (Note that this function doesn't actually apply the
-        update.)
-
-    Returns:
-      (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the
-      quadratic model, and
-      qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0)
-                    = qmodel(alpha*precon_grad + mu*prev_update) - L(theta).
-    """
-
-    cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses,
-                                                      variables)
-
-    # Compute the matrix-vector products with the transposed Fisher factor.
-    fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads)
-    fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates)
-    batch_size = math_ops.cast(
-        self._batch_size, dtype=fft_precon_grads[0].dtype)
-
-    # Compute the entries of the 2x2 matrix.
-    m_11 = (
-        _inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size +
-        self.damping * _inner_product_list(precon_grads, precon_grads))
-
-    m_21 = (
-        _inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size +
-        self.damping * _inner_product_list(prev_updates, precon_grads))
-
-    m_22 = (
-        _inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size +
-        self.damping * _inner_product_list(prev_updates, prev_updates))
-
-    def non_zero_prevupd_case():
-      r"""Computes optimal (alpha, mu) given non-zero previous update.
-
-      We solve the full 2x2 linear system. See Martens & Grosse (2015),
-      Section 7, definition of $\alpha^*$ and $\mu^*$.
-
-      Returns:
-        (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize
-        the quadratic model, and
-        qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0).
-      """
-      m = ops.convert_to_tensor([[m_11, m_21], [m_21, m_22]])
-
-      c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)],
-                                 [_inner_product_list(grads, prev_updates)]])
-
-      sol = -1. * _two_by_two_solve(m, c)
-      alpha = sol[0]
-      mu = sol[1]
-      qmodel_change = 0.5 * math_ops.reduce_sum(sol * c)
-
-      return alpha, mu, qmodel_change
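
For concreteness, the 2x2 solve in `non_zero_prevupd_case`, and the 1x1 degenerate case handled by `zero_prevupd_case` below, have a direct NumPy analogue (the inner-product values here are made up):

```python
# NumPy sketch of the quadratic-model solve above. m11, m21, m22 stand in for
# the damped Fisher inner products; c holds <grad, precon_grad> and
# <grad, prev_update>. All numbers are illustrative.
import numpy as np

m11, m21, m22 = 2.0, 0.3, 1.5
m = np.array([[m11, m21],
              [m21, m22]])
c = np.array([[0.8],   # <grad, precon_grad>
              [0.1]])  # <grad, prev_update>

sol = -np.linalg.solve(m, c)           # optimal [alpha, mu]
alpha, mu = sol[0, 0], sol[1, 0]
qmodel_change = 0.5 * np.sum(sol * c)  # qmodel(d*) - qmodel(0), negative here

# When the previous update is identically zero, the system reduces to 1x1:
alpha0 = -c[0, 0] / m11
qmodel_change0 = 0.5 * alpha0 * c[0, 0]
print(alpha, mu, qmodel_change, alpha0, qmodel_change0)
```
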
-    def zero_prevupd_case():
-      r"""Computes optimal (alpha, mu) given all-zero previous update.
-
-      The linear system reduces to 1x1. See Martens & Grosse (2015),
-      Section 6.4, definition of $\alpha^*$.
-
-      Returns:
-        (alpha, 0.0, qmodel_change), where alpha is chosen to optimize the
-        quadratic model, and
-        qmodel_change = qmodel(alpha*precon_grad) - qmodel(0).
-      """
-      m = m_11
-      c = _inner_product_list(grads, precon_grads)
-
-      alpha = -c / m
-      mu = 0.0
-      qmodel_change = 0.5 * alpha * c
-
-      return alpha, mu, qmodel_change
-
-    return control_flow_ops.cond(
-        math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case)
-
-  def _assign_q_model_change(self, q_model_change):
-    """Assigns `q_model_change` to `self._q_model_change` if damping is adapted.
-
-    Note that only the chief worker does the assignment.
-
-    Args:
-      q_model_change: Scalar tensor of type `float32`.
-
-    Returns:
-      If `adapt_damping` is `True` then returns an assign op; otherwise returns
-      a no_op().
-    """
-    if self._adapt_damping and self._is_chief:
-      q_model_assign_op = state_ops.assign(self._q_model_change, q_model_change)
-    else:
-      q_model_assign_op = control_flow_ops.no_op()
-    return q_model_assign_op
-
-  def _compute_qmodel_hyperparams_wrapper(self, grads_and_vars,
-                                          precon_grads_and_vars):
-    """Wrapper function for `self._compute_qmodel_hyperparams`.
-
-    Constructs a list of preconditioned gradients and variables. Also creates
-    an op to assign the computed q model change to `self._q_model_change`.
-
-    Args:
-      grads_and_vars: List of (gradient, variable) pairs.
-      precon_grads_and_vars: List of (preconditioned gradient, variable)
-        pairs.
-
-    Returns:
-      (alpha, mu, q_model_assign_op), where alpha and mu are chosen to optimize
-      the quadratic model, and `q_model_assign_op` assigns the computed q model
-      change to `self._q_model_change`.
-    """
-    precon_grads = list(
-        precon_grad for (precon_grad, _) in precon_grads_and_vars)
-    grads = list(grad for (grad, _) in grads_and_vars)
-    variables = list(var for (_, var) in grads_and_vars)
-    prev_updates = self._compute_prev_updates(variables)
-    # Compute optimal velocity update parameters according to quadratic model.
-    alpha, mu, q_model_change = self._compute_qmodel_hyperparams(
-        precon_grads, prev_updates, grads, variables)
-
-    return alpha, mu, self._assign_q_model_change(q_model_change)
-
-  def _compute_update_steps(self, grads_and_vars):
-    """Computes the update steps for the variables given the gradients.
-
-    Args:
-      grads_and_vars: List of (gradient, variable) pairs.
-
-    Returns:
-      A list of tuples (assign_op, var) where `assign_op` assigns the update
-      steps to `var`.
-    """
-
-    if self._momentum_type == "regular":
-      # Compute "preconditioned" gradient.
-      precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
-
-      # Apply "KL clipping" if asked for.
-      if self._norm_constraint is not None:
-        precon_grads_and_vars = self._clip_updates(grads_and_vars,
-                                                   precon_grads_and_vars)
-
-      # Update the velocity with this and return it as the step.
-      if self._adapt_damping and self._is_chief:
-        _, _, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
-            grads_and_vars, precon_grads_and_vars)
-        with ops.control_dependencies([q_model_assign_op]):
-          return self._update_velocities(precon_grads_and_vars, self._momentum)
-      else:
-        return self._update_velocities(precon_grads_and_vars, self._momentum)
-    elif self._momentum_type == "adam":
-      # Update velocity.
-      velocities_and_vars = self._update_velocities(grads_and_vars,
-                                                    self._momentum)
-      # Return "preconditioned" velocity vector as the step.
-      return self._fisher_est.multiply_inverse(velocities_and_vars)
-
-    elif self._momentum_type == "qmodel":
-      # Compute "preconditioned" gradient.
-      precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
-
-      # Compute optimal velocity update parameters according to quadratic model.
-      alpha, mu, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
-          grads_and_vars, precon_grads_and_vars)
-
-      with ops.control_dependencies([q_model_assign_op]):
-        return self._update_velocities(
-            precon_grads_and_vars, mu, vec_coeff=-alpha)
-
-  def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0):
-    """Updates the velocities of the variables with the given vectors.
-
-    Args:
-      vecs_and_vars: List of (vector, variable) pairs.
-      decay: How much to decay the old velocity by. This is often referred to
-        as the 'momentum constant'.
-      vec_coeff: Coefficient to apply to the vectors before adding them to the
-        velocity.
-
-    Returns:
-      A list of (velocity, var) indicating the new velocity for each var.
-    """
-
-    def _update_velocity(vec, var):
-      velocity = self._zeros_slot(var, "velocity", self._name)
-      with ops.colocate_with(velocity):
-        # NOTE(mattjj): read/modify/write race condition not suitable for async.
-
-        # Compute the new velocity for this variable.
-        new_velocity = decay * velocity + vec_coeff * vec
-
-        # Save the updated velocity.
-        return (array_ops.identity(velocity.assign(new_velocity)), var)
-
-    # Go through each variable and update its associated part of the velocity
-    # vector.
-    return [_update_velocity(vec, var) for vec, var in vecs_and_vars]
-
-  def _update_damping(self, prev_batch, global_step):
-    """Adapts the damping parameter. See Section 6.5 of the K-FAC paper.
-
-    The damping parameter is updated according to the Levenberg-Marquardt rule
-    every `self._damping_adaptation_interval` iterations.
-
-    Args:
-      prev_batch: Tensor or tuple of tensors which can be passed to
-        `self._loss_fn` to evaluate loss.
-      global_step: `Variable` which keeps track of the number of times the
-        training variables have been updated.
-
-    Returns:
-      A `tf.cond` op which updates the damping parameter.
-    """
-    def compute_damping():
-      """Adapts the damping parameter based on the "reduction ratio".
-
-      The reduction ratio captures how closely the quadratic approximation to
-      the loss function approximates the actual loss within a trust region. The
-      damping update tries to make the damping as small as possible while
-      maintaining the property that the quadratic model remains a good local
-      approximation to the loss function.
-
-      Returns:
-        An Op to assign the newly computed damping value to `self._damping`.
- """ - prev_batch_loss = self._loss_fn(prev_batch) - with ops.control_dependencies([prev_batch_loss]): - rho_assign = self._rho.assign( - (prev_batch_loss - self._prev_loss) / self._q_model_change) - with ops.control_dependencies([rho_assign]): - new_damping = control_flow_ops.case( - [(self._rho < 0.25, lambda: self.damping / self._omega), - (self._rho > 0.75, lambda: self.damping * self._omega)], - lambda: self.damping) - with ops.control_dependencies([new_damping]): - new_damping_min = math_ops.maximum(new_damping, self._min_damping) - return control_flow_ops.group(self._damping.assign(new_damping_min)) - - return control_flow_ops.cond( - math_ops.equal( - math_ops.mod(global_step + 1, self._damping_adaptation_interval), - 0), compute_damping, control_flow_ops.no_op) - - -def _inner_product_list(list1, list2): - return math_ops.add_n( - [math_ops.reduce_sum(elt1 * elt2) for elt1, elt2 in zip(list1, list2)]) - - -def _two_by_two_solve(m, c): - # it might be better just to crank out the exact formula for 2x2 inverses - return math_ops.matmul(linalg_ops.matrix_inverse(m), c) diff --git a/tensorflow/contrib/kfac/python/ops/optimizer_lib.py b/tensorflow/contrib/kfac/python/ops/optimizer_lib.py deleted file mode 100644 index 87d1866e06..0000000000 --- a/tensorflow/contrib/kfac/python/ops/optimizer_lib.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""The KFAC optimizer.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.optimizer import * -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import - -_allowed_symbols = [ - "KfacOptimizer", -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py deleted file mode 100644 index c4454325ae..0000000000 --- a/tensorflow/contrib/kfac/python/ops/placement.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ==============================================================================
-"""Implements placement strategies for cov and inv ops, cov variables."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import itertools
-
-from tensorflow.python.framework import ops as tf_ops
-
-
-def _make_thunk_on_device(func, device):
-  def thunk():
-    with tf_ops.device(device):
-      return func()
-  return thunk
-
-
-class RoundRobinPlacementMixin(object):
-  """Implements a round-robin placement strategy for ops and variables."""
-
-  def __init__(self, cov_devices=None, inv_devices=None, **kwargs):
-    """Initializes the RoundRobinPlacementMixin class.
-
-    Args:
-      cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
-        computations will be placed on these devices in a round-robin fashion.
-        Can be None, which means that no devices are specified.
-      inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
-        computations will be placed on these devices in a round-robin fashion.
-        Can be None, which means that no devices are specified.
-      **kwargs: Keyword arguments passed through to the superclass constructor.
-    """
-    super(RoundRobinPlacementMixin, self).__init__(**kwargs)
-    self._cov_devices = cov_devices
-    self._inv_devices = inv_devices
-
-  def make_vars_and_create_op_thunks(self, scope=None):
-    """Make vars and create op thunks w/ a round-robin device placement strategy.
-
-    For each factor, all of that factor's cov variables and their associated
-    update ops will be placed on a particular device. A new device is chosen
-    for each factor by cycling through the list of devices in the
-    `self._cov_devices` attribute. If `self._cov_devices` is `None` then no
-    explicit device placement occurs.
-
-    An analogous strategy is followed for inverse update ops, with the list of
-    devices being given by the `self._inv_devices` attribute.
-
-    Inverse variables, on the other hand, are not placed on any specific device
-    (they will just use the current device placement context, whatever
-    that happens to be). The idea is that the inverse variables belong where
-    they will be accessed most often, which is the device that actually applies
-    the preconditioner to the gradient. The user is responsible for setting
-    the device context for this.
-
-    Args:
-      scope: A string or None. If None it will be set to the name of this
-        estimator (given by the name property). All variables will be created,
-        and all thunks will execute, inside of a variable scope of the given
-        name. (Default: None)
-
-    Returns:
-      cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
-        the list of factors given by the "factors" property.
-      inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
-        the list of factors given by the "factors" property.
-    """
-    # Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`.
- (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw, - inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope) - - if self._cov_devices: - cov_update_thunks = [] - for cov_variable_thunk, cov_update_thunk, device in zip( - cov_variable_thunks_raw, cov_update_thunks_raw, - itertools.cycle(self._cov_devices)): - with tf_ops.device(device): - cov_variable_thunk() - cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk, - device)) - else: - for cov_variable_thunk in cov_variable_thunks_raw: - cov_variable_thunk() - cov_update_thunks = cov_update_thunks_raw - - for inv_variable_thunk in inv_variable_thunks_raw: - inv_variable_thunk() - - if self._inv_devices: - inv_update_thunks = [] - for inv_update_thunk, device in zip(inv_update_thunks_raw, - itertools.cycle(self._inv_devices)): - inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk, - device)) - else: - inv_update_thunks = inv_update_thunks_raw - - return cov_update_thunks, inv_update_thunks diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py deleted file mode 100644 index 144295f4c7..0000000000 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ /dev/null @@ -1,709 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility functions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu_function -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variables - -# Method used for inverting matrices. 
-POSDEF_INV_METHOD = "cholesky"
-POSDEF_EIG_METHOD = "self_adjoint"
-
-
-def set_global_constants(posdef_inv_method=None):
-  """Sets various global constants used by the classes in this module."""
-  global POSDEF_INV_METHOD
-
-  if posdef_inv_method is not None:
-    POSDEF_INV_METHOD = posdef_inv_method
-
-
-class SequenceDict(object):
-  """A dict convenience wrapper that allows getting/setting with sequences."""
-
-  def __init__(self, iterable=None):
-    self._dict = dict(iterable or [])
-
-  def __getitem__(self, key_or_keys):
-    if isinstance(key_or_keys, (tuple, list)):
-      return list(map(self.__getitem__, key_or_keys))
-    else:
-      return self._dict[key_or_keys]
-
-  def __setitem__(self, key_or_keys, val_or_vals):
-    if isinstance(key_or_keys, (tuple, list)):
-      for key, value in zip(key_or_keys, val_or_vals):
-        self[key] = value
-    else:
-      self._dict[key_or_keys] = val_or_vals
-
-  def items(self):
-    return list(self._dict.items())
-
-
-def tensors_to_column(tensors):
-  """Converts a tensor or list of tensors to a column vector.
-
-  Args:
-    tensors: A tensor or list of tensors.
-
-  Returns:
-    The tensors reshaped into vectors and stacked on top of each other.
-  """
-  if isinstance(tensors, (tuple, list)):
-    return array_ops.concat(
-        tuple(array_ops.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0)
-  else:
-    return array_ops.reshape(tensors, [-1, 1])
-
-
-def column_to_tensors(tensors_template, colvec):
-  """Converts a column vector back to the shape of the given template.
-
-  Args:
-    tensors_template: A tensor or list of tensors.
-    colvec: A 2d column vector with the same shape as the value of
-      tensors_to_column(tensors_template).
-
-  Returns:
-    X, where X is a tensor or list of tensors with the properties:
-     1) tensors_to_column(X) = colvec
-     2) X (or its elements) have the same shape as tensors_template (or its
-        elements)
-  """
-  if isinstance(tensors_template, (tuple, list)):
-    offset = 0
-    tensors = []
-    for tensor_template in tensors_template:
-      sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32)
-      tensor = array_ops.reshape(colvec[offset:(offset + sz)],
-                                 tensor_template.shape)
-      tensors.append(tensor)
-      offset += sz
-
-    tensors = tuple(tensors)
-  else:
-    tensors = array_ops.reshape(colvec, tensors_template.shape)
-
-  return tensors
-
-
-def kronecker_product(mat1, mat2):
-  """Computes the Kronecker product of two matrices."""
-  m1, n1 = mat1.get_shape().as_list()
-  mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1])
-  m2, n2 = mat2.get_shape().as_list()
-  mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
-  return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
-
-
-def layer_params_to_mat2d(vector):
-  """Converts a vector shaped like layer parameters to a 2D matrix.
-
-  In particular, we reshape the weights/filter component of the vector to be
-  2D, flattening all leading (input) dimensions. If there is a bias component,
-  we concatenate it to the reshaped weights/filter component.
-
-  Args:
-    vector: A Tensor or pair of Tensors shaped like layer parameters.
-
-  Returns:
-    A 2D Tensor with the same coefficients and the same output dimension.
-  """
-  if isinstance(vector, (tuple, list)):
-    w_part, b_part = vector
-    w_part_reshaped = array_ops.reshape(w_part,
-                                        [-1, w_part.shape.as_list()[-1]])
-    return array_ops.concat(
-        (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0)
-  elif isinstance(vector, ops.IndexedSlices):
-    return vector
-  else:  # Tensor or Tensor-like.
-    return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]])
-
-
-def mat2d_to_layer_params(vector_template, mat2d):
-  """Converts a canonical 2D matrix representation back to a vector.
-
-  Args:
-    vector_template: A Tensor or pair of Tensors shaped like layer parameters.
-    mat2d: A 2D Tensor with the same shape as the value of
-      layer_params_to_mat2d(vector_template).
-
-  Returns:
-    A Tensor or pair of Tensors with the same coefficients as mat2d and the same
-    shape as vector_template.
-  """
-  if isinstance(vector_template, (tuple, list)):
-    w_part, b_part = mat2d[:-1], mat2d[-1]
-    return array_ops.reshape(w_part, vector_template[0].shape), b_part
-  elif isinstance(vector_template, ops.IndexedSlices):
-    if not isinstance(mat2d, ops.IndexedSlices):
-      raise TypeError(
-          "If vector_template is an IndexedSlices, mat2d must be one too.")
-    return mat2d
-  else:
-    return array_ops.reshape(mat2d, vector_template.shape)
-
-
-def posdef_inv(tensor, damping):
-  """Computes the inverse of tensor + damping * identity."""
-  identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
-  damping = math_ops.cast(damping, dtype=tensor.dtype)
-  return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping)
-
-
-def posdef_inv_matrix_inverse(tensor, identity, damping):
-  """Computes inverse(tensor + damping * identity) directly."""
-  return linalg_ops.matrix_inverse(tensor + damping * identity)
-
-
-def posdef_inv_cholesky(tensor, identity, damping):
-  """Computes inverse(tensor + damping * identity) with Cholesky."""
-  chol = linalg_ops.cholesky(tensor + damping * identity)
-  return linalg_ops.cholesky_solve(chol, identity)
-
-
-def posdef_inv_eig(tensor, identity, damping):
-  """Computes inverse(tensor + damping * identity) with eigendecomposition."""
-  eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(
-      tensor + damping * identity)
-  return math_ops.matmul(
-      eigenvectors / eigenvalues, eigenvectors, transpose_b=True)
-
-
-posdef_inv_functions = {
-    "matrix_inverse": posdef_inv_matrix_inverse,
-    "cholesky": posdef_inv_cholesky,
-    "eig": posdef_inv_eig,
-}
-
-
-def posdef_eig(mat):
-  """Computes the eigendecomposition of a positive semidefinite matrix."""
-  return posdef_eig_functions[POSDEF_EIG_METHOD](mat)
-
-
-def posdef_eig_svd(mat):
-  """Computes the singular values and left singular vectors of a matrix."""
-  evals, evecs, _ = linalg_ops.svd(mat)
-
-  return evals, evecs
-
-
-def posdef_eig_self_adjoint(mat):
-  """Computes eigendecomposition using self_adjoint_eig."""
-  evals, evecs = linalg_ops.self_adjoint_eig(mat)
-  evals = math_ops.abs(evals)  # Should be equivalent to svd approach.
-
-  return evals, evecs
-
-
-posdef_eig_functions = {
-    "self_adjoint": posdef_eig_self_adjoint,
-    "svd": posdef_eig_svd,
-}
-
-
-def cholesky(tensor, damping):
-  """Computes the Cholesky factor of tensor + damping * identity."""
-  identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
-  damping = math_ops.cast(damping, dtype=tensor.dtype)
-  return linalg_ops.cholesky(tensor + damping * identity)
-
-
-class SubGraph(object):
-  """Defines a subgraph given by all the dependencies of a given set of outputs.
-  """
-
-  def __init__(self, outputs):
-    # Set of all ancestor Tensors and Ops of 'outputs'.
- self._members = set() - - self._iter_add(outputs) - - def _iter_add(self, root): - """Iteratively adds all of nodes' ancestors using depth first search.""" - stack = [root] - while stack: - nodes = stack.pop() - for node in nodes: - if node in self._members: - continue - self._members.add(node) - - if isinstance(node, ops.Tensor): - stack.append((node.op,)) - elif isinstance(node, ops.Operation): - stack.append(node.inputs) - - def is_member(self, node): - """Check if 'node' is in this subgraph.""" - return node in self._members - - def variable_uses(self, var): - """Computes number of times a variable is used. - - Args: - var: Variable or ResourceVariable instance. - - Returns: - Number of times a variable is used within this subgraph. - - Raises: - ValueError: If 'var' is not a variable type. - """ - if isinstance(var, resource_variable_ops.ResourceVariable): - var = var.handle - elif isinstance(var, variables.Variable): - var = var.value() - else: - raise ValueError("%s does not appear to be a variable." % str(var)) - - return len(self._members.intersection(set(var.consumers()))) - - def filter_list(self, node_list): - """Filters 'node_list' to nodes in this subgraph.""" - filtered_list = [] - for node in node_list: - if self.is_member(node): - filtered_list.append(node) - return filtered_list - - -def generate_random_signs(shape, dtype=dtypes.float32): - """Generate a random tensor with {-1, +1} entries.""" - ints = random_ops.random_uniform(shape, maxval=2, dtype=dtypes.int32) - return 2 * math_ops.cast(ints, dtype=dtype) - 1 - - -def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): - """Compute forward-mode gradients.""" - # See b/37888268. - - # This version of forward-mode autodiff is based on code by Tim Cooijmans - # and handles list arguments and certain special cases such as when the - # ys doesn't depend on one or more of the xs, and when ops.IndexedSlices are - # generated by the first gradients_impl.gradients call. - - us = [array_ops.zeros_like(y) + float("nan") for y in ys] - dydxs = gradients_impl.gradients( - ys, xs, grad_ys=us, stop_gradients=stop_gradients) - - # Deal with strange types that gradients_impl.gradients returns but can't - # deal with. - dydxs = [ - ops.convert_to_tensor(dydx) - if isinstance(dydx, ops.IndexedSlices) else dydx for dydx in dydxs - ] - dydxs = [ - array_ops.zeros_like(x) if dydx is None else dydx - for x, dydx in zip(xs, dydxs) - ] - - dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs) - - return dysdx - - -def on_tpu(): - """Returns True when building a TPU computation.""" - return tpu_function.get_tpu_context().number_of_shards is not None - - -def cross_replica_mean(tensor, name=None): - """Takes mean value of a Tensor across all TPU cores. - - Args: - tensor: Tensor to be synchronized. - name: None or string. Name of Op. - - Returns: - Average of Tensor across all TPU cores. - - Raises: - ValueError: If called outside of TPU context. 
-  """
-  with ops.name_scope(name, "cross_replica_mean", [tensor]):
-    num_shards = tpu_function.get_tpu_context().number_of_shards
-    if num_shards is None:
-      raise ValueError(
-          "Cannot take cross_replica_mean() outside of TPU Context.")
-    if num_shards == 1:
-      return tensor
-    return tpu_ops.cross_replica_sum(tensor / num_shards)
-
-
-def ensure_sequence(obj):
-  """If `obj` isn't a tuple or list, return a tuple containing `obj`."""
-  if isinstance(obj, (tuple, list)):
-    return obj
-  else:
-    return (obj,)
-
-
-def batch_execute(global_step, thunks, batch_size, name=None):
-  """Executes a subset of ops per global step.
-
-  Given a list of thunks, each of which produces a single stateful op,
-  ensures that exactly 'batch_size' ops are run per global step. Ops are
-  scheduled in a round-robin fashion. For example, with 3 ops and
-  batch_size == 2:
-
-    global_step | op0 | op1 | op2
-    ------------+-----+-----+-----
-    0           |  x  |  x  |
-    ------------+-----+-----+-----
-    1           |  x  |     |  x
-    ------------+-----+-----+-----
-    2           |     |  x  |  x
-    ------------+-----+-----+-----
-    3           |  x  |  x  |
-    ------------+-----+-----+-----
-    4           |  x  |     |  x
-
-  Does not guarantee order of op execution within a single global step.
-
-  Args:
-    global_step: Tensor indicating time. Determines which ops run.
-    thunks: List of thunks. Each thunk encapsulates one op. Return values are
-      ignored.
-    batch_size: int. Number of ops to execute per global_step.
-    name: string or None. Name scope for newly added ops.
-
-  Returns:
-    List of ops. Exactly 'batch_size' ops are guaranteed to have an effect
-    every global step.
-  """
-
-  def true_fn(thunk):
-    """Ensures thunk is executed and returns an Op (not a Tensor)."""
-
-    def result():
-      with ops.control_dependencies([thunk()]):
-        return control_flow_ops.no_op()
-
-    return result
-
-  def false_fn(_):
-    """Executes a no-op."""
-
-    def result():
-      return control_flow_ops.no_op()
-
-    return result
-
-  with ops.name_scope(name, "batch_execute"):
-    true_fns = [true_fn(thunk) for thunk in thunks]
-    false_fns = [false_fn(thunk) for thunk in thunks]
-    num_thunks = len(thunks)
-    conditions = [
-        math_ops.less(
-            math_ops.mod(batch_size - 1 + global_step * batch_size - j,
-                         num_thunks), batch_size) for j in range(num_thunks)
-    ]
-    result = [
-        control_flow_ops.cond(condition, true_fn, false_fn)
-        for (condition, true_fn,
-             false_fn) in zip(conditions, true_fns, false_fns)
-    ]
-    return result
-
-
-def extract_convolution_patches(inputs,
-                                filter_shape,
-                                padding,
-                                strides=None,
-                                dilation_rate=None,
-                                name=None,
-                                data_format=None):
-  """Extracts inputs to each output coordinate in tf.nn.convolution.
-
-  This is a generalization of tf.extract_image_patches() to tf.nn.convolution(),
-  where the number of spatial dimensions may be something other than 2.
-
-  Assumes:
-  - The first dimension of inputs is batch_size.
-  - The convolution filter is applied to all input channels.
-
-  Args:
-    inputs: Tensor of shape [batch_size, ..spatial_image_shape..,
-      in_channels]. Inputs to tf.nn.convolution().
-    filter_shape: List of ints. Shape of filter passed to tf.nn.convolution().
-    padding: string. Padding method. One of "VALID", "SAME".
-    strides: None or list of ints. Strides along spatial dimensions.
-    dilation_rate: None or list of ints. Dilation along spatial dimensions.
-    name: None or str. Name of Op.
-    data_format: None or str. Format of data.
-
-  Returns:
-    Tensor of shape [batch_size, ..spatial_output_shape..,
-    ..spatial_filter_shape.., in_channels].
-
-  Raises:
-    ValueError: If data_format does not put channel last.
- ValueError: If inputs and filter disagree on in_channels. - """ - if not is_data_format_channel_last(data_format): - raise ValueError("Channel must be last dimension.") - with ops.name_scope(name, "extract_convolution_patches", - [inputs, filter_shape, padding, strides, dilation_rate]): - batch_size = inputs.shape.as_list()[0] - in_channels = inputs.shape.as_list()[-1] - - # filter_shape = spatial_filter_shape + [in_channels, out_channels] - spatial_filter_shape = filter_shape[:-2] - if in_channels != filter_shape[-2]: - raise ValueError("inputs and filter_shape must agree on in_channels.") - - # Map each input feature to a location in the output. - out_channels = np.prod(spatial_filter_shape) * in_channels - filters = linalg_ops.eye(out_channels) - filters = array_ops.reshape( - filters, - list(spatial_filter_shape) + [in_channels, out_channels]) - - result = nn_ops.convolution( - inputs, - filters, - padding=padding, - strides=strides, - dilation_rate=dilation_rate) - spatial_output_shape = result.shape.as_list()[1:-1] - result = array_ops.reshape(result, - [batch_size or -1] + spatial_output_shape + - list(spatial_filter_shape) + [in_channels]) - - return result - - -def extract_pointwise_conv2d_patches(inputs, - filter_shape, - name=None, - data_format=None): - """Extract patches for a 1x1 conv2d. - - Args: - inputs: 4-D Tensor of shape [batch_size, height, width, in_channels]. - filter_shape: List of 4 ints. Shape of filter to apply with conv2d() - name: None or str. Name for Op. - data_format: None or str. Format for data. See 'data_format' in - tf.nn.conv2d() for details. - - Returns: - Tensor of shape [batch_size, ..spatial_input_shape.., - ..spatial_filter_shape.., in_channels] - - Raises: - ValueError: if inputs is not 4-D. - ValueError: if filter_shape is not [1, 1, ?, ?] - ValueError: if data_format is not channels-last. - """ - if inputs.shape.ndims != 4: - raise ValueError("inputs must have 4 dims.") - if len(filter_shape) != 4: - raise ValueError("filter_shape must have 4 dims.") - if filter_shape[0] != 1 or filter_shape[1] != 1: - raise ValueError("filter_shape must have shape 1 along spatial dimensions.") - if not is_data_format_channel_last(data_format): - raise ValueError("data_format must be channels last.") - with ops.name_scope(name, "extract_pointwise_conv2d_patches", - [inputs, filter_shape]): - ksizes = [1, 1, 1, 1] # Spatial shape is 1x1. - strides = [1, 1, 1, 1] # Operate on all pixels. - rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1. - padding = "VALID" # Doesn't matter. - result = array_ops.extract_image_patches(inputs, ksizes, strides, rates, - padding) - - batch_size, input_height, input_width, in_channels = inputs.shape.as_list() - filter_height, filter_width, in_channels, _ = filter_shape - return array_ops.reshape(result, [ - batch_size, input_height, input_width, filter_height, filter_width, - in_channels - ]) - - -def is_data_format_channel_last(data_format): - """True if data_format puts channel last.""" - if data_format is None: - return True - return data_format.endswith("C") - - -def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name - """Computes matmul(A, B) where A is sparse, B is dense. - - Args: - A: tf.IndexedSlices with dense shape [m, n]. - B: tf.Tensor with shape [n, k]. - name: str. Name of op. - transpose_a: Bool. If true we transpose A before multiplying it by B. - (Default: False) - transpose_b: Bool. If true we transpose B before multiplying it by A. 
- (Default: False) - - Returns: - tf.IndexedSlices resulting from matmul(A, B). - - Raises: - ValueError: If A doesn't represent a matrix. - ValueError: If B is not rank-2. - """ - with ops.name_scope(name, "matmul_sparse_dense", [A, B]): - if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2: - raise ValueError("A must represent a matrix. Found: %s." % A) - if B.shape.ndims != 2: - raise ValueError("B must be a matrix.") - new_values = math_ops.matmul( - A.values, B, transpose_a=transpose_a, transpose_b=transpose_b) - return ops.IndexedSlices( - new_values, - A.indices, - dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]])) - - -def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name - """Computes matmul(A, B) where A is a diagonal matrix, B is sparse. - - Args: - A_diag: diagonal entries of matrix A of shape [m, m]. - B: tf.IndexedSlices. Represents matrix of shape [m, n]. - name: str. Name of op. - - Returns: - tf.IndexedSlices resulting from matmul(A, B). - - Raises: - ValueError: If A_diag is not rank-1. - ValueError: If B doesn't represent a matrix. - """ - with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]): - A_diag = ops.convert_to_tensor(A_diag) - if A_diag.shape.ndims != 1: - raise ValueError("A_diag must be a rank-1 Tensor.") - if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2: - raise ValueError("B must represent a matrix. Found: %s." % B) - a = array_ops.gather(A_diag, B.indices) - a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1)) - return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape) - - -class PartitionedTensor(object): - """A Tensor partitioned across its 0-th dimension.""" - - def __init__(self, tensors): - """Initializes PartitionedTensor. - - Args: - tensors: List of Tensors. All Tensors must agree on shape (excepting - batch dimension) and dtype. - - Raises: - ValueError: If 'tensors' has length zero. - ValueError: if contents of 'tensors' don't agree on shape or dtype. - """ - if not tensors: - raise ValueError("tensors must be a list of 1+ Tensors.") - - dtype = tensors[0].dtype - if not all(tensor.dtype == dtype for tensor in tensors): - raise ValueError("all tensors must have dtype = %s." % dtype) - - shape = tensors[0].shape[1:] - if not all(tensor.shape[1:] == shape for tensor in tensors): - raise ValueError("All tensors must have shape = %s (excluding batch " - "dimension)." 
% shape) - - self.tensors = tensors - self._concats = {} # {device: Tensor} - - @property - def shape(self): - feature_shape = self.tensors[0].shape[1:] - batch_size = sum([tensor.shape[0] for tensor in self.tensors], - tensor_shape.Dimension(0)) - return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape) - - def get_shape(self): - return self.shape - - @property - def dtype(self): - return self.tensors[0].dtype - - def __str__(self): - return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % ( - self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list())) - - def __hash__(self): - return hash(tuple(self.tensors)) - - def __eq__(self, other): - if not isinstance(other, PartitionedTensor): - return False - return self.tensors == other.tensors - - def __ne__(self, other): - return not self == other # pylint: disable=g-comparison-negation - - def __getitem__(self, key): - return self.as_tensor()[key] - - def as_tensor(self, dtype=None, name=None, as_ref=False): - with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors): - assert not as_ref - assert dtype in [None, self.dtype] - result = array_ops.concat(self.tensors, axis=0) - - # Cache 'result' if we haven't already cached a value for this device. - if result.device not in self._concats: - self._concats[result.device] = result - return self._concats[result.device] - - @property - def device(self): - # PartitionedTensors in general do not live on a single device. If the - # device cannot be determined unambiguously this property will return None. - device = self.tensors[0].device - if all(tensor.device == device for tensor in self.tensors): - return device - return None - - -ops.register_tensor_conversion_function( - PartitionedTensor, - lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref)) - - -# TODO(b/69623235): Add a function for finding tensors that share gradients -# to eliminate redundant fisher factor computations. diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py deleted file mode 100644 index 330d222dbf..0000000000 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ==============================================================================
-"""Utility functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.utils import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
-    "set_global_constants",
-    "SequenceDict",
-    "tensors_to_column",
-    "column_to_tensors",
-    "kronecker_product",
-    "layer_params_to_mat2d",
-    "mat2d_to_layer_params",
-    "posdef_inv",
-    "posdef_inv_matrix_inverse",
-    "posdef_inv_cholesky",
-    "posdef_inv_functions",
-    "SubGraph",
-    "generate_random_signs",
-    "fwd_gradients",
-    "ensure_sequence",
-    "batch_execute",
-    "extract_convolution_patches",
-    "extract_pointwise_conv2d_patches",
-    "is_data_format_channel_last",
-    "matmul_sparse_dense",
-    "matmul_diag_sparse",
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
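
As a final sanity check on the deleted `utils.kronecker_product`, the reshape-and-broadcast trick it uses agrees with `np.kron`; the following standalone NumPy sketch (not part of the diff) demonstrates the equivalence:

```python
# NumPy sketch of the reshape/broadcast Kronecker product used by
# utils.kronecker_product above; it matches np.kron on 2-D matrices.
import numpy as np

def kronecker_product(mat1, mat2):
    m1, n1 = mat1.shape
    m2, n2 = mat2.shape
    # Broadcast (m1, 1, n1, 1) against (1, m2, 1, n2), then flatten the
    # paired row and column axes together.
    mat1_rsh = mat1.reshape(m1, 1, n1, 1)
    mat2_rsh = mat2.reshape(1, m2, 1, n2)
    return (mat1_rsh * mat2_rsh).reshape(m1 * m2, n1 * n2)

a = np.arange(4.0).reshape(2, 2)
b = np.arange(6.0).reshape(2, 3)
assert np.allclose(kronecker_product(a, b), np.kron(a, b))
```
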