From 7ece1c0b8e527d59d8082cd6428cd255e5700074 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 1 Nov 2017 11:55:32 -0700 Subject: Moving model_pruning library to tf.contrib PiperOrigin-RevId: 174214419 --- tensorflow/BUILD | 1 + tensorflow/contrib/BUILD | 1 + tensorflow/contrib/__init__.py | 1 + tensorflow/contrib/cmake/tf_python.cmake | 5 + tensorflow/contrib/model_pruning/BUILD | 139 +++++ tensorflow/contrib/model_pruning/README.md | 195 +++++++ tensorflow/contrib/model_pruning/__init__.py | 46 ++ .../contrib/model_pruning/examples/cifar10/BUILD | 77 +++ .../model_pruning/examples/cifar10/cifar10_eval.py | 178 +++++++ .../examples/cifar10/cifar10_input.py | 256 +++++++++ .../examples/cifar10/cifar10_pruning.py | 395 ++++++++++++++ .../examples/cifar10/cifar10_train.py | 159 ++++++ .../model_pruning/python/layers/core_layers.py | 477 +++++++++++++++++ .../contrib/model_pruning/python/layers/layers.py | 364 +++++++++++++ .../model_pruning/python/layers/layers_test.py | 139 +++++ .../model_pruning/python/layers/rnn_cells.py | 340 ++++++++++++ .../model_pruning/python/layers/rnn_cells_test.py | 85 +++ .../contrib/model_pruning/python/learning.py | 188 +++++++ tensorflow/contrib/model_pruning/python/pruning.py | 585 +++++++++++++++++++++ .../contrib/model_pruning/python/pruning_test.py | 162 ++++++ 20 files changed, 3793 insertions(+) create mode 100644 tensorflow/contrib/model_pruning/BUILD create mode 100644 tensorflow/contrib/model_pruning/README.md create mode 100644 tensorflow/contrib/model_pruning/__init__.py create mode 100644 tensorflow/contrib/model_pruning/examples/cifar10/BUILD create mode 100644 tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval.py create mode 100644 tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py create mode 100644 tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py create mode 100644 tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py create mode 100644 
tensorflow/contrib/model_pruning/python/layers/core_layers.py create mode 100644 tensorflow/contrib/model_pruning/python/layers/layers.py create mode 100644 tensorflow/contrib/model_pruning/python/layers/layers_test.py create mode 100644 tensorflow/contrib/model_pruning/python/layers/rnn_cells.py create mode 100644 tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py create mode 100644 tensorflow/contrib/model_pruning/python/learning.py create mode 100644 tensorflow/contrib/model_pruning/python/pruning.py create mode 100644 tensorflow/contrib/model_pruning/python/pruning_test.py diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 03cf745a36..f2cdf37dbf 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -413,6 +413,7 @@ filegroup( "//tensorflow/contrib/makefile:all_files", "//tensorflow/contrib/meta_graph_transform:all_files", "//tensorflow/contrib/metrics:all_files", + "//tensorflow/contrib/model_pruning:all_files", "//tensorflow/contrib/mpi_collectives:all_files", "//tensorflow/contrib/ndlstm:all_files", "//tensorflow/contrib/nearest_neighbor:all_files", diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 2e9b96bb1d..3d53cbba56 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -57,6 +57,7 @@ py_library( "//tensorflow/contrib/memory_stats:memory_stats_py", "//tensorflow/contrib/meta_graph_transform", "//tensorflow/contrib/metrics:metrics_py", + "//tensorflow/contrib/model_pruning", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/contrib/ndlstm", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index a26fdb982c..3068e9ed8f 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -51,6 +51,7 @@ from tensorflow.contrib import lookup from tensorflow.contrib import losses from tensorflow.contrib import memory_stats from tensorflow.contrib import metrics +from tensorflow.contrib import 
model_pruning from tensorflow.contrib import nccl from tensorflow.contrib import nn from tensorflow.contrib import opt diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 277818b159..1c5fb5a97d 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -518,6 +518,11 @@ add_python_module("tensorflow/contrib/metrics/python") add_python_module("tensorflow/contrib/metrics/python/kernel_tests") add_python_module("tensorflow/contrib/metrics/python/metrics") add_python_module("tensorflow/contrib/metrics/python/ops") +add_python_module("tensorflow/contrib/model_pruning") +add_python_module("tensorflow/contrib/model_pruning/examples") +add_python_module("tensorflow/contrib/model_pruning/examples/cifar10") +add_python_module("tensorflow/contrib/model_pruning/python") +add_python_module("tensorflow/contrib/model_pruning/python/layers") add_python_module("tensorflow/contrib/ndlstm") add_python_module("tensorflow/contrib/ndlstm/python") add_python_module("tensorflow/contrib/nn") diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD new file mode 100644 index 0000000000..ca3f13479e --- /dev/null +++ b/tensorflow/contrib/model_pruning/BUILD @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +package(default_visibility = ["//tensorflow:__subpackages__"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "core_layers", + srcs = ["python/layers/core_layers.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:layers", + "//tensorflow/python:ops", + "//tensorflow/python:platform", + ], +) + +py_library( + name = "layers", + srcs = ["python/layers/layers.py"], + srcs_version = "PY2AND3", + deps = [ + ":core_layers", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/layers:layers_py", + "//third_party/py/numpy", + ], +) + +py_test( + name = "layers_test", + size = "small", + srcs = ["python/layers/layers_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":layers", + "//tensorflow/python:client_testlib", + ], +) + +py_library( + name = "learning", + srcs = ["python/learning.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/slim", + ], +) + +py_library( + name = "rnn_cells", + srcs = ["python/layers/rnn_cells.py"], + srcs_version = "PY2AND3", + deps = [ + ":core_layers", + ], +) + +py_library( + name = "pruning", + srcs = ["python/pruning.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":core_layers", + "//tensorflow/contrib/training:training_py", + "//tensorflow/python:platform", + "//third_party/py/numpy", + ], +) + +py_test( + name = "pruning_test", + size = "small", + srcs = ["python/pruning_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pruning", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "rnn_cells_test", + size = "small", + srcs = ["python/layers/rnn_cells_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pruning", + ":rnn_cells", + "//tensorflow/python:client_testlib", + ], +) + +py_library( + name = "init_py", + srcs = ["__init__.py"], + 
srcs_version = "PY2AND3", +) + +# Top-level library +py_library( + name = "model_pruning", + srcs_version = "PY2AND3", + deps = [ + ":init_py", + ":layers", + ":learning", + ":pruning", + ":rnn_cells", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md new file mode 100644 index 0000000000..a8427e6014 --- /dev/null +++ b/tensorflow/contrib/model_pruning/README.md @@ -0,0 +1,195 @@ +# Model pruning: Training tensorflow models to have masked connections + +This document describes the API that facilitates magnitude-based pruning of +neural network's weight tensors. The API helps inject necessary tensorflow op +into the training graph so the model can be pruned while it is being trained. + +### Model creation + +The first step involves adding mask and threshold variables to the layers that +need to undergo pruning. The variable mask is the same shape as the layer's +weight tensor and determines which of the weights participate in the forward +execution of the graph. This can be achieved by wrapping the weight tensor of +the layer with the `apply_mask` function provided in +[pruning.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/pruning.py). +For example: + +```python +conv = tf.nn.conv2d(images, pruning.apply_mask(weights), stride, padding) +``` + +This creates a convolutional layer with additional variables mask and threshold +as shown below: ![Convolutional layer with mask and +threshold](./mask.png "Convolutional layer with mask and threshold") + +Alternatively, the API also provides variant of tensorflow layers with these +auxiliary variables built-in (see +[layers](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers)) +. 
Layers currently supported: + +* [layers.masked_conv2d](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/layers.py?l=83) + +* [layers.masked_fully_connected](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/layers.py?l=241) + +* [rnn_cells.MaskedLSTMCell](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py?l=154) + +### Pruning-related hyperparameters + +The pruning library allows for specification of the following hyperparameters: + +| Hyperparameter | Type | Default | Description | +| ---------------------------- | ------- | ------------- | -------------- | +| name | string | model_pruning | Name of the | +: : : : pruning : +: : : : specification. : +: : : : Used for : +: : : : adding : +: : : : summaries and : +: : : : ops under a : +: : : : common : +: : : : tensorflow : +: : : : name_scope : +| begin_pruning_step | integer | 0 | The global | +: : : : step at which : +: : : : to begin : +: : : : pruning : +| end_pruning_step | integer | -1 | The global | +: : : : step at which : +: : : : to terminate : +: : : : pruning. : +: : : : Defaults to -1 : +: : : : implying that : +: : : : pruning : +: : : : continues till : +: : : : the training : +: : : : stops : +| do_not_prune | list of | [""] | list of layers | +: : strings : : that are not : +: : : : pruned : +| threshold_decay | float | 0.9 | The decay | +: : : : factor to use : +: : : : for : +: : : : exponential : +: : : : decay of the : +: : : : thresholds : +| pruning_frequency | integer | 10 | How often | +: : : : should the : +: : : : masks be : +: : : : updated? (in # : +: : : : of : +: : : : global_steps).
: +| nbins | integer | 255 | Number of bins | +: : : : to use for : +: : : : histogram : +: : : : computation : +| initial_sparsity | float | 0.0 | Initial | +: : : : sparsity value : +| target_sparsity | float | 0.5 | Target | +: : : : sparsity value : +| sparsity_function_begin_step | integer | 0 | The global | +: : : : step at : +: : : : which the : +: : : : gradual : +: : : : sparsity : +: : : : function : +: : : : begins to take : +: : : : effect : +| sparsity_function_end_step | integer | 100 | The global | +: : : : step used as : +: : : : the end point : +: : : : for the : +: : : : gradual : +: : : : sparsity : +: : : : function : +| sparsity_function_exponent | float | 3.0 | exponent = 1 | +: : : : is linearly : +: : : : varying : +: : : : sparsity : +: : : : between : +: : : : initial and : +: : : : final. : +: : : : exponent > 1 : +: : : : varies more : +: : : : slowly towards : +: : : : the end than : +: : : : the beginning : + +The sparsity $$s_t$$ at global step $$t$$ is given by: + +$$ s_{t}=s_{f}+\left(s_{i}-s_{f}\right)\left(1-\frac{t-t_{0}}{n\Delta t}\right)^{3} $$ + +The interval between sparsity_function_begin_step and sparsity_function_end_step +is divided into $$n$$ intervals of size equal to the pruning_frequency ($$\Delta +t$$). $$s_f$$ is the target_sparsity, $$s_i$$ is the initial_sparsity, $$t_0$$ +is the sparsity_function_begin_step. In this equation, the +sparsity_function_exponent is set to 3. +### Adding pruning ops to the training graph + +The final step involves adding ops to the training graph that monitor the +distribution of the layer's weight magnitudes and determine the layer threshold +such that masking all the weights below this threshold achieves the sparsity level +desired for the current training step.
This can be achieved as follows: + +```python +tf.app.flags.DEFINE_string( + 'pruning_hparams', '', + """Comma separated list of pruning-related hyperparameters""") + +with tf.Graph().as_default(): + + # Create global step variable + global_step = tf.train.get_or_create_global_step() + + # Parse pruning hyperparameters + pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams) + + # Create a pruning object using the pruning specification + p = pruning.Pruning(pruning_hparams, global_step=global_step) + + # Add conditional mask update op. Executing this op will update all + # the masks in the graph if the current global step is in the range + # [begin_pruning_step, end_pruning_step] as specified by the pruning spec + mask_update_op = p.conditional_mask_update_op() + + # Add summaries to keep track of the sparsity in different layers during training + p.add_pruning_summaries() + + with tf.train.MonitoredTrainingSession(...) as mon_sess: + # Run the usual training op in the tf session + mon_sess.run(train_op) + + # Update the masks by running the mask_update_op + mon_sess.run(mask_update_op) + +``` + +## Example: Pruning and training deep CNNs on the cifar10 dataset + +Please see https://www.tensorflow.org/tutorials/deep_cnn for details on neural +network architecture, setting up inputs etc. The additional changes needed to +incorporate pruning are captured in the following: + +* [cifar10_pruning.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py) + creates a deep CNN with the same architecture, but adds mask and threshold + variables for each of the weight tensors in the convolutional and + locally-connected layers. + +* [cifar10_train.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py) + adds pruning ops to the training graph as described above.
+ +To train the pruned version of cifar10: + +```bash +$ examples_dir=contrib/model_pruning/examples +$ bazel build -c opt $examples_dir/cifar10:cifar10_{train,eval} +$ bazel-bin/$examples_dir/cifar10/cifar10_train --pruning_hparams=name=cifar10_pruning,begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000 +``` + +Eval: + +```shell +$ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once +``` + +TODO(suyoggupta): Add figures showing the sparsity function, sparsity for +different layers etc. diff --git a/tensorflow/contrib/model_pruning/__init__.py b/tensorflow/contrib/model_pruning/__init__.py new file mode 100644 index 0000000000..aaeb2238a4 --- /dev/null +++ b/tensorflow/contrib/model_pruning/__init__.py @@ -0,0 +1,46 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Model pruning implementation in tensorflow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import +from tensorflow.contrib.model_pruning.python.layers.layers import masked_conv2d +from tensorflow.contrib.model_pruning.python.layers.layers import masked_convolution +from tensorflow.contrib.model_pruning.python.layers.layers import masked_fully_connected +from tensorflow.contrib.model_pruning.python.layers.rnn_cells import MaskedBasicLSTMCell +from tensorflow.contrib.model_pruning.python.layers.rnn_cells import MaskedLSTMCell +from tensorflow.contrib.model_pruning.python.learning import train +from tensorflow.contrib.model_pruning.python.pruning import apply_mask +from tensorflow.contrib.model_pruning.python.pruning import get_masked_weights +from tensorflow.contrib.model_pruning.python.pruning import get_masks +from tensorflow.contrib.model_pruning.python.pruning import get_thresholds +from tensorflow.contrib.model_pruning.python.pruning import get_weight_sparsity +from tensorflow.contrib.model_pruning.python.pruning import get_weights +from tensorflow.contrib.model_pruning.python.pruning import Pruning +# pylint: enable=unused-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'masked_convolution', 'masked_conv2d', 'masked_fully_connected', + 'MaskedBasicLSTMCell', 'MaskedLSTMCell', 'train', 'apply_mask', + 'get_masked_weights', 'get_masks', 'get_thresholds', 'get_weights', + 'get_weight_sparsity', 'Pruning' +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/BUILD b/tensorflow/contrib/model_pruning/examples/cifar10/BUILD new file mode 100644 index 0000000000..299278ae75 --- /dev/null +++ b/tensorflow/contrib/model_pruning/examples/cifar10/BUILD @@ -0,0 +1,77 @@ +# 
Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Description: +# Example TensorFlow models for CIFAR-10 + +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "cifar10_input", + srcs = ["cifar10_input.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "cifar10_pruning", + srcs = ["cifar10_pruning.py"], + srcs_version = "PY2AND3", + deps = [ + ":cifar10_input", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "cifar10_eval", + srcs = [ + "cifar10_eval.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":cifar10_pruning", + ], +) + +py_binary( + name = "cifar10_train", + srcs = [ + "cifar10_train.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":cifar10_pruning", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval.py new file mode 100644 index 0000000000..d72b2a1dca --- /dev/null +++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval.py @@ -0,0 +1,178 @@ +# Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Evaluation for CIFAR-10. + +Accuracy: +cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs +of data) as judged by cifar10_eval.py. + +Speed: +On a single Tesla K40, cifar10_train.py processes a single batch of 128 images +in 0.25-0.35 sec (i.e. 350 - 600 images /sec). The model reaches ~86% +accuracy after 100K steps in 8 hours of training time. + +Usage: +Please see the tutorial and website for how to download the CIFAR-10 +data set, compile the program and train the model. + +http://tensorflow.org/tutorials/deep_cnn/ +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import datetime +import math +import sys +import time + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_pruning as cifar10 + +FLAGS = None + + +def eval_once(saver, summary_writer, top_k_op, summary_op): + """Run Eval once. + + Args: + saver: Saver. + summary_writer: Summary writer. + top_k_op: Top K op. + summary_op: Summary op. 
+ """ + with tf.Session() as sess: + ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) + if ckpt and ckpt.model_checkpoint_path: + # Restores from checkpoint + saver.restore(sess, ckpt.model_checkpoint_path) + # Assuming model_checkpoint_path looks something like: + # /my-favorite-path/cifar10_train/model.ckpt-0, + # extract global_step from it. + global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] + else: + print('No checkpoint file found') + return + + # Start the queue runners. + coord = tf.train.Coordinator() + try: + threads = [] + for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS): + threads.extend(qr.create_threads(sess, coord=coord, daemon=True, + start=True)) + + num_iter = int(math.ceil(FLAGS.num_examples / 128)) + true_count = 0 # Counts the number of correct predictions. + total_sample_count = num_iter * 128 + step = 0 + while step < num_iter and not coord.should_stop(): + predictions = sess.run([top_k_op]) + true_count += np.sum(predictions) + step += 1 + + # Compute precision @ 1. + precision = true_count / total_sample_count + print('%s: precision @ 1 = %.3f' % (datetime.datetime.now(), precision)) + + summary = tf.Summary() + summary.ParseFromString(sess.run(summary_op)) + summary.value.add(tag='Precision @ 1', simple_value=precision) + summary_writer.add_summary(summary, global_step) + except Exception as e: # pylint: disable=broad-except + coord.request_stop(e) + + coord.request_stop() + coord.join(threads, stop_grace_period_secs=10) + + +def evaluate(): + """Eval CIFAR-10 for a number of steps.""" + with tf.Graph().as_default() as g: + # Get images and labels for CIFAR-10. + eval_data = FLAGS.eval_data == 'test' + images, labels = cifar10.inputs(eval_data=eval_data) + + # Build a Graph that computes the logits predictions from the + # inference model. + logits = cifar10.inference(images) + + # Calculate predictions. 
+ top_k_op = tf.nn.in_top_k(logits, labels, 1) + + # Restore the moving average version of the learned variables for eval. + variable_averages = tf.train.ExponentialMovingAverage( + cifar10.MOVING_AVERAGE_DECAY) + variables_to_restore = variable_averages.variables_to_restore() + saver = tf.train.Saver(variables_to_restore) + + # Build the summary operation based on the TF collection of Summaries. + summary_op = tf.summary.merge_all() + + summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g) + + while True: + eval_once(saver, summary_writer, top_k_op, summary_op) + if FLAGS.run_once: + break + time.sleep(FLAGS.eval_interval_secs) + + +def main(argv=None): # pylint: disable=unused-argument + cifar10.maybe_download_and_extract() + if tf.gfile.Exists(FLAGS.eval_dir): + tf.gfile.DeleteRecursively(FLAGS.eval_dir) + tf.gfile.MakeDirs(FLAGS.eval_dir) + evaluate() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--eval_dir', + type=str, + default='/tmp/cifar10_eval', + help='Directory where to write event logs.') + parser.add_argument( + '--eval_data', + type=str, + default='test', + help="""Either 'test' or 'train_eval'.""") + parser.add_argument( + '--checkpoint_dir', + type=str, + default='/tmp/cifar10_train', + help="""Directory where to read model checkpoints.""") + parser.add_argument( + '--eval_interval_secs', + type=int, + default=60 * 5, + help='How often to run the eval.') + parser.add_argument( + '--num_examples', + type=int, + default=10000, + help='Number of examples to run.') + parser.add_argument( + '--run_once', + type=bool, + default=False, + help='Whether to run eval only once.') + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py new file mode 100644 index 0000000000..d07fece4bc --- /dev/null +++ 
b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py @@ -0,0 +1,256 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Routine for decoding the CIFAR-10 binary file format.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from six.moves import xrange # pylint: disable=redefined-builtin +import tensorflow as tf + +# Process images of this size. Note that this differs from the original CIFAR +# image size of 32 x 32. If one alters this number, then the entire model +# architecture will change and any model would need to be retrained. +IMAGE_SIZE = 24 + +# Global constants describing the CIFAR-10 data set. +NUM_CLASSES = 10 +NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000 +NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000 + + +def read_cifar10(filename_queue): + """Reads and parses examples from CIFAR10 data files. + + Recommendation: if you want N-way read parallelism, call this function + N times. This will give you N independent Readers reading different + files & positions within those files, which will give better mixing of + examples. + + Args: + filename_queue: A queue of strings with the filenames to read from. 
+ + Returns: + An object representing a single example, with the following fields: + height: number of rows in the result (32) + width: number of columns in the result (32) + depth: number of color channels in the result (3) + key: a scalar string Tensor describing the filename & record number + for this example. + label: an int32 Tensor with the label in the range 0..9. + uint8image: a [height, width, depth] uint8 Tensor with the image data + """ + + class CIFAR10Record(object): + pass + result = CIFAR10Record() + + # Dimensions of the images in the CIFAR-10 dataset. + # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the + # input format. + label_bytes = 1 # 2 for CIFAR-100 + result.height = 32 + result.width = 32 + result.depth = 3 + image_bytes = result.height * result.width * result.depth + # Every record consists of a label followed by the image, with a + # fixed number of bytes for each. + record_bytes = label_bytes + image_bytes + + # Read a record, getting filenames from the filename_queue. No + # header or footer in the CIFAR-10 format, so we leave header_bytes + # and footer_bytes at their default of 0. + reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) + result.key, value = reader.read(filename_queue) + + # Convert from a string to a vector of uint8 that is record_bytes long. + record_bytes = tf.decode_raw(value, tf.uint8) + + # The first bytes represent the label, which we convert from uint8->int32. + result.label = tf.cast( + tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32) + + # The remaining bytes after the label represent the image, which we reshape + # from [depth * height * width] to [depth, height, width]. + depth_major = tf.reshape( + tf.strided_slice(record_bytes, [label_bytes], + [label_bytes + image_bytes]), + [result.depth, result.height, result.width]) + # Convert from [depth, height, width] to [height, width, depth]. 
+ result.uint8image = tf.transpose(depth_major, [1, 2, 0]) + + return result + + +def _generate_image_and_label_batch(image, label, min_queue_examples, + batch_size, shuffle): + """Construct a queued batch of images and labels. + + Args: + image: 3-D Tensor of [height, width, 3] of type.float32. + label: 1-D Tensor of type.int32 + min_queue_examples: int32, minimum number of samples to retain + in the queue that provides of batches of examples. + batch_size: Number of images per batch. + shuffle: boolean indicating whether to use a shuffling queue. + + Returns: + images: Images. 4D tensor of [batch_size, height, width, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + """ + # Create a queue that shuffles the examples, and then + # read 'batch_size' images + labels from the example queue. + num_preprocess_threads = 16 + if shuffle: + images, label_batch = tf.train.shuffle_batch( + [image, label], + batch_size=batch_size, + num_threads=num_preprocess_threads, + capacity=min_queue_examples + 3 * batch_size, + min_after_dequeue=min_queue_examples) + else: + images, label_batch = tf.train.batch( + [image, label], + batch_size=batch_size, + num_threads=num_preprocess_threads, + capacity=min_queue_examples + 3 * batch_size) + + # Display the training images in the visualizer. + tf.summary.image('images', images) + + return images, tf.reshape(label_batch, [batch_size]) + + +def distorted_inputs(data_dir, batch_size): + """Construct distorted input for CIFAR training using the Reader ops. + + Args: + data_dir: Path to the CIFAR-10 data directory. + batch_size: Number of images per batch. + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. 
+ """ + filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) + for i in xrange(1, 6)] + for f in filenames: + if not tf.gfile.Exists(f): + raise ValueError('Failed to find file: ' + f) + + # Create a queue that produces the filenames to read. + filename_queue = tf.train.string_input_producer(filenames) + + # Read examples from files in the filename queue. + read_input = read_cifar10(filename_queue) + reshaped_image = tf.cast(read_input.uint8image, tf.float32) + + height = IMAGE_SIZE + width = IMAGE_SIZE + + # Image processing for training the network. Note the many random + # distortions applied to the image. + + # Randomly crop a [height, width] section of the image. + distorted_image = tf.random_crop(reshaped_image, [height, width, 3]) + + # Randomly flip the image horizontally. + distorted_image = tf.image.random_flip_left_right(distorted_image) + + # Because these operations are not commutative, consider randomizing + # the order their operation. + distorted_image = tf.image.random_brightness(distorted_image, + max_delta=63) + distorted_image = tf.image.random_contrast(distorted_image, + lower=0.2, upper=1.8) + + # Subtract off the mean and divide by the variance of the pixels. + float_image = tf.image.per_image_standardization(distorted_image) + + # Set the shapes of tensors. + float_image.set_shape([height, width, 3]) + read_input.label.set_shape([1]) + + # Ensure that the random shuffling has good mixing properties. + min_fraction_of_examples_in_queue = 0.4 + min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * + min_fraction_of_examples_in_queue) + print ('Filling queue with %d CIFAR images before starting to train. ' + 'This will take a few minutes.' % min_queue_examples) + + # Generate a batch of images and labels by building up a queue of examples. 
+ return _generate_image_and_label_batch(float_image, read_input.label, + min_queue_examples, batch_size, + shuffle=True) + + +def inputs(eval_data, data_dir, batch_size): + """Construct input for CIFAR evaluation using the Reader ops. + + Args: + eval_data: bool, indicating if one should use the train or eval data set. + data_dir: Path to the CIFAR-10 data directory. + batch_size: Number of images per batch. + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + """ + if not eval_data: + filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) + for i in xrange(1, 6)] + num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN + else: + filenames = [os.path.join(data_dir, 'test_batch.bin')] + num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL + + for f in filenames: + if not tf.gfile.Exists(f): + raise ValueError('Failed to find file: ' + f) + + # Create a queue that produces the filenames to read. + filename_queue = tf.train.string_input_producer(filenames) + + # Read examples from files in the filename queue. + read_input = read_cifar10(filename_queue) + reshaped_image = tf.cast(read_input.uint8image, tf.float32) + + height = IMAGE_SIZE + width = IMAGE_SIZE + + # Image processing for evaluation. + # Crop the central [height, width] of the image. + resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, + width, height) + + # Subtract off the mean and divide by the variance of the pixels. + float_image = tf.image.per_image_standardization(resized_image) + + # Set the shapes of tensors. + float_image.set_shape([height, width, 3]) + read_input.label.set_shape([1]) + + # Ensure that the random shuffling has good mixing properties. + min_fraction_of_examples_in_queue = 0.4 + min_queue_examples = int(num_examples_per_epoch * + min_fraction_of_examples_in_queue) + + # Generate a batch of images and labels by building up a queue of examples. 
+ return _generate_image_and_label_batch(float_image, read_input.label, + min_queue_examples, batch_size, + shuffle=False) diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py new file mode 100644 index 0000000000..0d1de869f6 --- /dev/null +++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py @@ -0,0 +1,395 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Builds the CIFAR-10 network with additional variables to support pruning. + +Summary of available functions: + + # Compute input images and labels for training. If you would like to run + # evaluations, use inputs() instead. + inputs, labels = distorted_inputs() + + # Compute inference on the model inputs to make a prediction. + predictions = inference(inputs) + + # Compute the total loss of the prediction with respect to the labels. + loss = loss(predictions, labels) + + # Create a graph to run one step of training with respect to the loss. 
+ train_op = train(loss, global_step) +""" +# pylint: disable=missing-docstring +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import re +import sys +import tarfile + +from six.moves import urllib +import tensorflow as tf + +from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_input +from tensorflow.contrib.model_pruning.python import pruning + +# Global constants describing the CIFAR-10 data set. +IMAGE_SIZE = cifar10_input.IMAGE_SIZE +NUM_CLASSES = cifar10_input.NUM_CLASSES +NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN +NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL +BATCH_SIZE = 128 +DATA_DIR = '/tmp/cifar10_data' + +# Constants describing the training process. +MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average. +NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays. +LEARNING_RATE_DECAY_FACTOR = 0.1 # Learning rate decay factor. +INITIAL_LEARNING_RATE = 0.1 # Initial learning rate. + +# If a model is trained with multiple GPUs, prefix all Op names with tower_name +# to differentiate the operations. Note that this prefix is removed from the +# names of the summaries when visualizing a model. +TOWER_NAME = 'tower' + +DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz' + + +def _activation_summary(x): + """Helper to create summaries for activations. + + Creates a summary that provides a histogram of activations. + Creates a summary that measures the sparsity of activations. + + Args: + x: Tensor + Returns: + nothing + """ + # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training + # session. This helps the clarity of presentation on tensorboard. 
+ tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name) + tf.summary.histogram(tensor_name + '/activations', x) + tf.summary.scalar(tensor_name + '/sparsity', + tf.nn.zero_fraction(x)) + + +def _variable_on_cpu(name, shape, initializer): + """Helper to create a Variable stored on CPU memory. + + Args: + name: name of the variable + shape: list of ints + initializer: initializer for Variable + + Returns: + Variable Tensor + """ + with tf.device('/cpu:0'): + dtype = tf.float32 + var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype) + return var + + +def _variable_with_weight_decay(name, shape, stddev, wd): + """Helper to create an initialized Variable with weight decay. + + Note that the Variable is initialized with a truncated normal distribution. + A weight decay is added only if one is specified. + + Args: + name: name of the variable + shape: list of ints + stddev: standard deviation of a truncated Gaussian + wd: add L2Loss weight decay multiplied by this float. If None, weight + decay is not added for this Variable. + + Returns: + Variable Tensor + """ + dtype = tf.float32 + var = _variable_on_cpu( + name, + shape, + tf.truncated_normal_initializer(stddev=stddev, dtype=dtype)) + if wd is not None: + weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss') + tf.add_to_collection('losses', weight_decay) + return var + + +def distorted_inputs(): + """Construct distorted input for CIFAR training using the Reader ops. + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. 
+ + Raises: + ValueError: If no data_dir + """ + if not DATA_DIR: + raise ValueError('Please supply a data_dir') + data_dir = os.path.join(DATA_DIR, 'cifar-10-batches-bin') + images, labels = cifar10_input.distorted_inputs( + data_dir=data_dir, batch_size=BATCH_SIZE) + return images, labels + + +def inputs(eval_data): + """Construct input for CIFAR evaluation using the Reader ops. + + Args: + eval_data: bool, indicating if one should use the train or eval data set. + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + + Raises: + ValueError: If no data_dir + """ + if not DATA_DIR: + raise ValueError('Please supply a data_dir') + data_dir = os.path.join(DATA_DIR, 'cifar-10-batches-bin') + images, labels = cifar10_input.inputs( + eval_data=eval_data, data_dir=data_dir, batch_size=BATCH_SIZE) + return images, labels + + +def inference(images): + """Build the CIFAR-10 model. + + Args: + images: Images returned from distorted_inputs() or inputs(). + + Returns: + Logits. + """ + # We instantiate all variables using tf.get_variable() instead of + # tf.Variable() in order to share variables across multiple GPU training runs. + # If we only ran this model on a single GPU, we could simplify this function + # by replacing all instances of tf.get_variable() with tf.Variable(). + # + # While instantiating conv and local layers, we add mask and threshold + # variables to the layer by calling the pruning.apply_mask() function. 
+ # Note that the masks are applied only to the weight tensors + # conv1 + with tf.variable_scope('conv1') as scope: + kernel = _variable_with_weight_decay('weights', + shape=[5, 5, 3, 64], + stddev=5e-2, + wd=0.0) + + conv = tf.nn.conv2d( + images, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME') + biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0)) + pre_activation = tf.nn.bias_add(conv, biases) + conv1 = tf.nn.relu(pre_activation, name=scope.name) + _activation_summary(conv1) + + # pool1 + pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], + padding='SAME', name='pool1') + # norm1 + norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, + name='norm1') + + # conv2 + with tf.variable_scope('conv2') as scope: + kernel = _variable_with_weight_decay('weights', + shape=[5, 5, 64, 64], + stddev=5e-2, + wd=0.0) + conv = tf.nn.conv2d( + norm1, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME') + biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1)) + pre_activation = tf.nn.bias_add(conv, biases) + conv2 = tf.nn.relu(pre_activation, name=scope.name) + _activation_summary(conv2) + + # norm2 + norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, + name='norm2') + # pool2 + pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], + strides=[1, 2, 2, 1], padding='SAME', name='pool2') + + # local3 + with tf.variable_scope('local3') as scope: + # Move everything into depth so we can perform a single matrix multiply. 
+ reshape = tf.reshape(pool2, [BATCH_SIZE, -1]) + dim = reshape.get_shape()[1].value + weights = _variable_with_weight_decay('weights', shape=[dim, 384], + stddev=0.04, wd=0.004) + biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1)) + local3 = tf.nn.relu( + tf.matmul(reshape, pruning.apply_mask(weights, scope)) + biases, + name=scope.name) + _activation_summary(local3) + + # local4 + with tf.variable_scope('local4') as scope: + weights = _variable_with_weight_decay('weights', shape=[384, 192], + stddev=0.04, wd=0.004) + biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1)) + local4 = tf.nn.relu( + tf.matmul(local3, pruning.apply_mask(weights, scope)) + biases, + name=scope.name) + _activation_summary(local4) + + # linear layer(WX + b), + # We don't apply softmax here because + # tf.nn.sparse_softmax_cross_entropy_with_logits accepts the unscaled logits + # and performs the softmax internally for efficiency. + with tf.variable_scope('softmax_linear') as scope: + weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES], + stddev=1/192.0, wd=0.0) + biases = _variable_on_cpu('biases', [NUM_CLASSES], + tf.constant_initializer(0.0)) + softmax_linear = tf.add( + tf.matmul(local4, pruning.apply_mask(weights, scope)), + biases, + name=scope.name) + _activation_summary(softmax_linear) + + return softmax_linear + + +def loss(logits, labels): + """Add L2Loss to all the trainable variables. + + Add summary for "Loss" and "Loss/avg". + Args: + logits: Logits from inference(). + labels: Labels from distorted_inputs or inputs(). 1-D tensor + of shape [batch_size] + + Returns: + Loss tensor of type float. + """ + # Calculate the average cross entropy loss across the batch. 
+ labels = tf.cast(labels, tf.int64) + cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=labels, logits=logits, name='cross_entropy_per_example') + cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy') + tf.add_to_collection('losses', cross_entropy_mean) + + # The total loss is defined as the cross entropy loss plus all of the weight + # decay terms (L2 loss). + return tf.add_n(tf.get_collection('losses'), name='total_loss') + + +def _add_loss_summaries(total_loss): + """Add summaries for losses in CIFAR-10 model. + + Generates moving average for all losses and associated summaries for + visualizing the performance of the network. + + Args: + total_loss: Total loss from loss(). + Returns: + loss_averages_op: op for generating moving averages of losses. + """ + # Compute the moving average of all individual losses and the total loss. + loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg') + losses = tf.get_collection('losses') + loss_averages_op = loss_averages.apply(losses + [total_loss]) + + # Attach a scalar summary to all individual losses and the total loss; do the + # same for the averaged version of the losses. + for l in losses + [total_loss]: + # Name each loss as '(raw)' and name the moving average version of the loss + # as the original loss name. + tf.summary.scalar(l.op.name + ' (raw)', l) + tf.summary.scalar(l.op.name, loss_averages.average(l)) + + return loss_averages_op + + +def train(total_loss, global_step): + """Train CIFAR-10 model. + + Create an optimizer and apply to all trainable variables. Add moving + average for all trainable variables. + + Args: + total_loss: Total loss from loss(). + global_step: Integer Variable counting the number of training steps + processed. + Returns: + train_op: op for training. + """ + # Variables that affect learning rate. 
+ num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / BATCH_SIZE + decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY) + + # Decay the learning rate exponentially based on the number of steps. + lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE, + global_step, + decay_steps, + LEARNING_RATE_DECAY_FACTOR, + staircase=True) + tf.summary.scalar('learning_rate', lr) + + # Generate moving averages of all losses and associated summaries. + loss_averages_op = _add_loss_summaries(total_loss) + + # Compute gradients. + with tf.control_dependencies([loss_averages_op]): + opt = tf.train.GradientDescentOptimizer(lr) + grads = opt.compute_gradients(total_loss) + + # Apply gradients. + apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) + + # Add histograms for trainable variables. + for var in tf.trainable_variables(): + tf.summary.histogram(var.op.name, var) + + # Add histograms for gradients. + for grad, var in grads: + if grad is not None: + tf.summary.histogram(var.op.name + '/gradients', grad) + + # Track the moving averages of all trainable variables. 
+ variable_averages = tf.train.ExponentialMovingAverage( + MOVING_AVERAGE_DECAY, global_step) + variables_averages_op = variable_averages.apply(tf.trainable_variables()) + + with tf.control_dependencies([apply_gradient_op, variables_averages_op]): + train_op = tf.no_op(name='train') + + return train_op + + +def maybe_download_and_extract(): + """Download and extract the tarball from Alex's website.""" + dest_directory = DATA_DIR + if not os.path.exists(dest_directory): + os.makedirs(dest_directory) + filename = DATA_URL.split('/')[-1] + filepath = os.path.join(dest_directory, filename) + if not os.path.exists(filepath): + def _progress(count, block_size, total_size): + sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, + float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) + print() + statinfo = os.stat(filepath) + print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') + + tarfile.open(filepath, 'r:gz').extractall(dest_directory) diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py new file mode 100644 index 0000000000..a1064a3b6a --- /dev/null +++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py @@ -0,0 +1,159 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""A binary to train pruned CIFAR-10 using a single GPU. + +Accuracy: +cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of +data) as judged by cifar10_eval.py when target sparsity in +cifar10_pruning_spec.pbtxt is set to zero + +Results: +Sparsity | Accuracy after 150K steps +-------- | ------------------------- +0% | 86% +50% | 86% +75% | TODO(suyoggupta) +90% | TODO(suyoggupta) +95% | 77% + +Usage: +Please see the tutorial and website for how to download the CIFAR-10 +data set, compile the program and train the model. + + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import datetime +import sys +import time + + +import tensorflow as tf + +from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_pruning as cifar10 +from tensorflow.contrib.model_pruning.python import pruning + +FLAGS = None + + +def train(): + """Train CIFAR-10 for a number of steps.""" + with tf.Graph().as_default(): + global_step = tf.contrib.framework.get_or_create_global_step() + + # Get images and labels for CIFAR-10. + images, labels = cifar10.distorted_inputs() + + # Build a Graph that computes the logits predictions from the + # inference model. + logits = cifar10.inference(images) + + # Calculate loss. + loss = cifar10.loss(logits, labels) + + # Build a Graph that trains the model with one batch of examples and + # updates the model parameters. 
+ train_op = cifar10.train(loss, global_step) + + # Parse pruning hyperparameters + pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams) + + # Create a pruning object using the pruning hyperparameters + pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step) + + # Use the pruning_obj to add ops to the training graph to update the masks + # The conditional_mask_update_op will update the masks only when the + # training step is in [begin_pruning_step, end_pruning_step] specified in + # the pruning spec proto + mask_update_op = pruning_obj.conditional_mask_update_op() + + # Use the pruning_obj to add summaries to the graph to track the sparsity + # of each of the layers + pruning_obj.add_pruning_summaries() + + class _LoggerHook(tf.train.SessionRunHook): + """Logs loss and runtime.""" + + def begin(self): + self._step = -1 + + def before_run(self, run_context): + self._step += 1 + self._start_time = time.time() + return tf.train.SessionRunArgs(loss) # Asks for loss value. 
+ + def after_run(self, run_context, run_values): + duration = time.time() - self._start_time + loss_value = run_values.results + if self._step % 10 == 0: + num_examples_per_step = 128 + examples_per_sec = num_examples_per_step / duration + sec_per_batch = float(duration) + + format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' + 'sec/batch)') + print(format_str % (datetime.datetime.now(), self._step, loss_value, + examples_per_sec, sec_per_batch)) + + with tf.train.MonitoredTrainingSession( + checkpoint_dir=FLAGS.train_dir, + hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), + tf.train.NanTensorHook(loss), + _LoggerHook()], + config=tf.ConfigProto( + log_device_placement=FLAGS.log_device_placement)) as mon_sess: + while not mon_sess.should_stop(): + mon_sess.run(train_op) + # Update the masks + mon_sess.run(mask_update_op) + + +def main(argv=None): # pylint: disable=unused-argument + cifar10.maybe_download_and_extract() + if tf.gfile.Exists(FLAGS.train_dir): + tf.gfile.DeleteRecursively(FLAGS.train_dir) + tf.gfile.MakeDirs(FLAGS.train_dir) + train() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--train_dir', + type=str, + default='/tmp/cifar10_train', + help='Directory where to write event logs and checkpoint.') + parser.add_argument( + '--pruning_hparams', + type=str, + default='', + help="""Comma separated list of pruning-related hyperparameters""") + parser.add_argument( + '--max_steps', + type=int, + default=1000000, + help='Number of batches to run.') + parser.add_argument( + '--log_device_placement', + type=bool, + default=False, + help='Whether to log device placement.') + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/model_pruning/python/layers/core_layers.py b/tensorflow/contrib/model_pruning/python/layers/core_layers.py new file mode 100644 index 0000000000..ae60d8b1e1 --- /dev/null +++ 
b/tensorflow/contrib/model_pruning/python/layers/core_layers.py @@ -0,0 +1,477 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains the core layer classes for model pruning and its functional aliases. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base +from tensorflow.python.layers import utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import standard_ops + +MASK_COLLECTION = 'masks' +THRESHOLD_COLLECTION = 'thresholds' +MASKED_WEIGHT_COLLECTION = 'masked_weights' +WEIGHT_COLLECTION = 'kernel' +# The 'weights' part of the name is needed for the quantization library +# to recognize that the kernel should be quantized. +MASKED_WEIGHT_NAME = 'weights/masked_weight' + + +class _MaskedConv(base.Layer): + """Abstract nD convolution layer (private, used as implementation base). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. 
The weight tensor of this layer is masked. + If `use_bias` is True (and a `bias_initializer` is provided), + a bias vector is created and added to the outputs. Finally, if + `activation` is not `None`, it is applied to the outputs as well. + + Arguments: + rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of n integers, specifying the + length of the convolution window. + strides: An integer or tuple/list of n integers, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, ..., channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, ...)`. + dilation_rate: An integer or tuple/list of n integers, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. + bias_initializer: An initializer for the bias vector. If None, no bias will + be applied. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. 
+ """ + + def __init__(self, + rank, + filters, + kernel_size, + strides=1, + padding='valid', + data_format='channels_last', + dilation_rate=1, + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + trainable=True, + name=None, + **kwargs): + super(_MaskedConv, self).__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs) + self.rank = rank + self.filters = filters + self.kernel_size = utils.normalize_tuple(kernel_size, rank, 'kernel_size') + self.strides = utils.normalize_tuple(strides, rank, 'strides') + self.padding = utils.normalize_padding(padding) + self.data_format = utils.normalize_data_format(data_format) + self.dilation_rate = utils.normalize_tuple(dilation_rate, rank, + 'dilation_rate') + self.activation = activation + self.use_bias = use_bias + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + self.kernel_regularizer = kernel_regularizer + self.bias_regularizer = bias_regularizer + self.input_spec = base.InputSpec(ndim=self.rank + 2) + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + channel_axis = 1 if self.data_format == 'channels_first' else -1 + if input_shape[channel_axis].value is None: + raise ValueError('The channel dimension of the inputs ' + 'should be defined. 
Found `None`.') + input_dim = input_shape[channel_axis].value + kernel_shape = self.kernel_size + (input_dim, self.filters) + self.mask = self.add_variable( + name='mask', + shape=kernel_shape, + initializer=init_ops.ones_initializer(), + trainable=False, + dtype=self.dtype) + + self.kernel = self.add_variable( + name='kernel', + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + trainable=True, + dtype=self.dtype) + + self.threshold = self.add_variable( + name='threshold', + shape=[], + initializer=init_ops.zeros_initializer(), + trainable=False, + dtype=self.dtype) + + # Add masked_weights in the weights namescope so as to make it easier + # for the quantization library to add quant ops. + self.masked_kernel = math_ops.multiply(self.mask, self.kernel, + MASKED_WEIGHT_NAME) + + ops.add_to_collection(MASK_COLLECTION, self.mask) + ops.add_to_collection(MASKED_WEIGHT_COLLECTION, self.masked_kernel) + ops.add_to_collection(THRESHOLD_COLLECTION, self.threshold) + ops.add_to_collection(WEIGHT_COLLECTION, self.kernel) + + if self.use_bias: + self.bias = self.add_variable( + name='bias', + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + trainable=True, + dtype=self.dtype) + else: + self.bias = None + self.input_spec = base.InputSpec( + ndim=self.rank + 2, axes={channel_axis: input_dim}) + self.built = True + + def call(self, inputs): + outputs = nn.convolution( + input=inputs, + filter=self.masked_kernel, + dilation_rate=self.dilation_rate, + strides=self.strides, + padding=self.padding.upper(), + data_format=utils.convert_data_format(self.data_format, self.rank + 2)) + + if self.bias is not None: + if self.data_format == 'channels_first': + if self.rank == 1: + # nn.bias_add does not accept a 1D input tensor. 
+ bias = array_ops.reshape(self.bias, (1, self.filters, 1)) + outputs += bias + if self.rank == 2: + outputs = nn.bias_add(outputs, self.bias, data_format='NCHW') + if self.rank == 3: + # As of Mar 2017, direct addition is significantly slower than + # bias_add when computing gradients. To use bias_add, we collapse Z + # and Y into a single dimension to obtain a 4D input tensor. + outputs_shape = outputs.shape.as_list() + outputs_4d = array_ops.reshape(outputs, [ + outputs_shape[0], outputs_shape[1], + outputs_shape[2] * outputs_shape[3], outputs_shape[4] + ]) + outputs_4d = nn.bias_add(outputs_4d, self.bias, data_format='NCHW') + outputs = array_ops.reshape(outputs_4d, outputs_shape) + else: + outputs = nn.bias_add(outputs, self.bias, data_format='NHWC') + + if self.activation is not None: + return self.activation(outputs) + return outputs + + def _compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape).as_list() + if self.data_format == 'channels_last': + space = input_shape[1:-1] + new_space = [] + for i in range(len(space)): + new_dim = utils.conv_output_length( + space[i], + self.kernel_size[i], + padding=self.padding, + stride=self.strides[i], + dilation=self.dilation_rate[i]) + new_space.append(new_dim) + return tensor_shape.TensorShape([input_shape[0]] + new_space + + [self.filters]) + else: + space = input_shape[2:] + new_space = [] + for i in range(len(space)): + new_dim = utils.conv_output_length( + space[i], + self.kernel_size[i], + padding=self.padding, + stride=self.strides[i], + dilation=self.dilation_rate[i]) + new_space.append(new_dim) + return tensor_shape.TensorShape([input_shape[0], self.filters] + + new_space) + + +class MaskedConv2D(_MaskedConv): + """2D convolution layer (e.g. spatial convolution over images). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. 
If `use_bias` is True (and a `bias_initializer` is provided), + a bias vector is created and added to the outputs. Finally, if + `activation` is not `None`, it is applied to the outputs as well. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. + bias_initializer: An initializer for the bias vector. If None, no bias will + be applied. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Regularizer function for the output. 
+ trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + """ + + def __init__(self, + filters, + kernel_size, + strides=(1, 1), + padding='valid', + data_format='channels_last', + dilation_rate=(1, 1), + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + trainable=True, + name=None, + **kwargs): + super(MaskedConv2D, self).__init__( + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + trainable=trainable, + name=name, + **kwargs) + + +class MaskedFullyConnected(base.Layer): + """Fully-connected layer class with masked weights. + + This layer implements the operation: + `outputs = activation(inputs.kernel + bias)` + Where `activation` is the activation function passed as the `activation` + argument (if not `None`), `kernel` is a weights matrix created by the layer, + and `bias` is a bias vector created by the layer + (only if `use_bias` is `True`). + + Note: if the input to the layer has a rank greater than 2, then it is + flattened prior to the initial matrix multiply by `kernel`. + + Arguments: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (callable). Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: Initializer function for the weight matrix. + bias_initializer: Initializer function for the bias. 
+ kernel_regularizer: Regularizer function for the weight matrix. + bias_regularizer: Regularizer function for the bias. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require reuse=True in such cases. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (callable). + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: Initializer instance (or name) for the weight matrix. + bias_initializer: Initializer instance (or name) for the bias. + kernel_regularizer: Regularizer instance for the weight matrix (callable) + bias_regularizer: Regularizer instance for the bias (callable). + activity_regularizer: Regularizer instance for the output (callable) + kernel: Weight matrix (TensorFlow variable or tensor). + bias: Bias vector, if applicable (TensorFlow variable or tensor). 
+ """ + + def __init__(self, + units, + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + trainable=True, + name=None, + **kwargs): + super(MaskedFullyConnected, self).__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs) + self.units = units + self.activation = activation + self.use_bias = use_bias + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + self.kernel_regularizer = kernel_regularizer + self.bias_regularizer = bias_regularizer + self.input_spec = base.InputSpec(min_ndim=2) + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + if input_shape[-1].value is None: + raise ValueError('The last dimension of the inputs to `Dense` ' + 'should be defined. Found `None`.') + self.input_spec = base.InputSpec( + min_ndim=2, axes={-1: input_shape[-1].value}) + + self.kernel = self.add_variable( + 'kernel', + shape=[input_shape[-1].value, self.units], + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + dtype=self.dtype, + trainable=True) + + self.mask = self.add_variable( + name='mask', + shape=[input_shape[-1].value, self.units], + initializer=init_ops.ones_initializer(), + trainable=False, + dtype=self.dtype) + + self.threshold = self.add_variable( + name='threshold', + shape=[], + initializer=init_ops.zeros_initializer(), + trainable=False, + dtype=self.dtype) + + # Add masked_weights in the weights namescope so as to make it easier + # for the quantization library to add quant ops. 
+ self.masked_kernel = math_ops.multiply(self.mask, self.kernel, + MASKED_WEIGHT_NAME) + + ops.add_to_collection(MASK_COLLECTION, self.mask) + ops.add_to_collection(MASKED_WEIGHT_COLLECTION, self.masked_kernel) + ops.add_to_collection(THRESHOLD_COLLECTION, self.threshold) + ops.add_to_collection(WEIGHT_COLLECTION, self.kernel) + + if self.use_bias: + self.bias = self.add_variable( + 'bias', + shape=[ + self.units, + ], + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + dtype=self.dtype, + trainable=True) + else: + self.bias = None + self.built = True + + def call(self, inputs): + inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) + shape = inputs.get_shape().as_list() + output_shape = shape[:-1] + [self.units] + if len(output_shape) > 2: + # Broadcasting is required for the inputs. + outputs = standard_ops.tensordot(inputs, self.masked_kernel, + [[len(shape) - 1], [0]]) + # Reshape the output back to the original ndim of the input. + outputs.set_shape(output_shape) + else: + outputs = standard_ops.matmul(inputs, self.masked_kernel) + if self.use_bias: + outputs = nn.bias_add(outputs, self.bias) + if self.activation is not None: + return self.activation(outputs) # pylint: disable=not-callable + return outputs + + def _compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + input_shape = input_shape.with_rank_at_least(2) + if input_shape[-1].value is None: + raise ValueError( + 'The innermost dimension of input_shape must be defined, but saw: %s' + % input_shape) + return input_shape[:-1].concatenate(self.units) diff --git a/tensorflow/contrib/model_pruning/python/layers/layers.py b/tensorflow/contrib/model_pruning/python/layers/layers.py new file mode 100644 index 0000000000..dfebb9a679 --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/layers/layers.py @@ -0,0 +1,364 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tensorflow layers with added variables for parameter masking. + +Branched from tensorflow/contrib/layers/python/layers/layers.py +""" +# pylint: disable=missing-docstring +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import six + +from tensorflow.contrib.framework.python.ops import add_arg_scope +from tensorflow.contrib.framework.python.ops import variables +from tensorflow.contrib.layers.python.layers import initializers +from tensorflow.contrib.layers.python.layers import utils +from tensorflow.contrib.model_pruning.python.layers import core_layers as core +from tensorflow.python.framework import ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as tf_variables + + +def _model_variable_getter(getter, + name, + shape=None, + dtype=None, + initializer=None, + regularizer=None, + trainable=True, + collections=None, + caching_device=None, + partitioner=None, + rename=None, + use_resource=None, + **_): + """Getter that uses model_variable for compatibility with core layers.""" + short_name = name.split('/')[-1] + if rename and short_name in rename: + name_components = name.split('/') + name_components[-1] = 
rename[short_name] + name = '/'.join(name_components) + return variables.model_variable( + name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + collections=collections, + trainable=trainable, + caching_device=caching_device, + partitioner=partitioner, + custom_getter=getter, + use_resource=use_resource) + + +def _build_variable_getter(rename=None): + """Build a model variable getter that respects scope getter and renames.""" + + # VariableScope will nest the getters + def layer_variable_getter(getter, *args, **kwargs): + kwargs['rename'] = rename + return _model_variable_getter(getter, *args, **kwargs) + + return layer_variable_getter + + +def _add_variable_to_collections(variable, collections_set, collections_name): + """Adds variable (or all its parts) to all collections with that name.""" + collections = utils.get_variable_collections(collections_set, + collections_name) or [] + variables_list = [variable] + if isinstance(variable, tf_variables.PartitionedVariable): + variables_list = [v for v in variable] + for collection in collections: + for var in variables_list: + if var not in ops.get_collection(collection): + ops.add_to_collection(collection, var) + + +@add_arg_scope +def masked_convolution(inputs, + num_outputs, + kernel_size, + stride=1, + padding='SAME', + data_format=None, + rate=1, + activation_fn=nn.relu, + normalizer_fn=None, + normalizer_params=None, + weights_initializer=initializers.xavier_initializer(), + weights_regularizer=None, + biases_initializer=init_ops.zeros_initializer(), + biases_regularizer=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + scope=None): + """Adds an 2D convolution followed by an optional batch_norm layer. + The layer creates a mask variable on top of the weight variable. The input to + the convolution operation is the elementwise multiplication of the mask + variable and the weigh + + It is required that 1 <= N <= 3. 
+ + `convolution` creates a variable called `weights`, representing the + convolutional kernel, that is convolved (actually cross-correlated) with the + `inputs` to produce a `Tensor` of activations. If a `normalizer_fn` is + provided (such as `batch_norm`), it is then applied. Otherwise, if + `normalizer_fn` is None and a `biases_initializer` is provided then a `biases` + variable would be created and added to the activations. Finally, if + `activation_fn` is not `None`, it is applied to the activations as well. + + Performs atrous convolution with input stride/dilation rate equal to `rate` + if a value > 1 for any dimension of `rate` is specified. In this case + `stride` values != 1 are not supported. + + Args: + inputs: A Tensor of rank N+2 of shape + `[batch_size] + input_spatial_shape + [in_channels]` if data_format does + not start with "NC" (default), or + `[batch_size, in_channels] + input_spatial_shape` if data_format starts + with "NC". + num_outputs: Integer, the number of output filters. + kernel_size: A sequence of N positive integers specifying the spatial + dimensions of the filters. Can be a single integer to specify the same + value for all spatial dimensions. + stride: A sequence of N positive integers specifying the stride at which to + compute output. Can be a single integer to specify the same value for all + spatial dimensions. Specifying any `stride` value != 1 is incompatible + with specifying any `rate` value != 1. + padding: One of `"VALID"` or `"SAME"`. + data_format: A string or None. Specifies whether the channel dimension of + the `input` and output is the last dimension (default, or if `data_format` + does not start with "NC"), or the second dimension (if `data_format` + starts with "NC"). For N=1, the valid values are "NWC" (default) and + "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". + For N=3, the valid values are "NDHWC" (default) and "NCDHW". 
+ rate: A sequence of N positive integers specifying the dilation rate to use + for atrous convolution. Can be a single integer to specify the same + value for all spatial dimensions. Specifying any `rate` value != 1 is + incompatible with specifying any `stride` value != 1. + activation_fn: Activation function. The default value is a ReLU function. + Explicitly set it to None to skip it and maintain a linear activation. + normalizer_fn: Normalization function to use instead of `biases`. If + `normalizer_fn` is provided then `biases_initializer` and + `biases_regularizer` are ignored and `biases` are not created nor added. + Defaults to None for no normalizer function. + normalizer_params: Normalization function parameters. + weights_initializer: An initializer for the weights. + weights_regularizer: Optional regularizer for the weights. + biases_initializer: An initializer for the biases. If None skip biases. + biases_regularizer: Optional regularizer for the biases. + reuse: Whether or not the layer and its variables should be reused. To be + able to reuse the layer, scope must be given. + variables_collections: Optional list of collections for all the variables or + a dictionary containing a different list of collections per variable. + outputs_collections: Collection to add the outputs. + trainable: If `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). + scope: Optional scope for `variable_scope`. + + Returns: + A tensor representing the output of the operation. + + Raises: + ValueError: If `data_format` is invalid. + ValueError: If both `rate` and `stride` are not uniformly 1. 
+ """ + if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC', 'NCDHW']: + raise ValueError('Invalid data_format: %r' % (data_format,)) + + layer_variable_getter = _build_variable_getter({ + 'bias': 'biases', + 'kernel': 'weights' + }) + + with variable_scope.variable_scope( + scope, 'Conv', [inputs], reuse=reuse, + custom_getter=layer_variable_getter) as sc: + inputs = ops.convert_to_tensor(inputs) + input_rank = inputs.get_shape().ndims + + if input_rank == 3: + raise ValueError('Sparse Convolution not supported for input with rank', + input_rank) + elif input_rank == 4: + layer_class = core.MaskedConv2D + elif input_rank == 5: + raise ValueError('Sparse Convolution not supported for input with rank', + input_rank) + else: + raise ValueError('Sparse Convolution not supported for input with rank', + input_rank) + + if data_format is None or data_format == 'NHWC': + df = 'channels_last' + elif data_format == 'NCHW': + df = 'channels_first' + else: + raise ValueError('Unsupported data fromat', data_format) + + layer = layer_class( + filters=num_outputs, + kernel_size=kernel_size, + strides=stride, + padding=padding, + data_format=df, + dilation_rate=rate, + activation=None, + use_bias=not normalizer_fn and biases_initializer, + kernel_initializer=weights_initializer, + bias_initializer=biases_initializer, + kernel_regularizer=weights_regularizer, + bias_regularizer=biases_regularizer, + activity_regularizer=None, + trainable=trainable, + name=sc.name, + dtype=inputs.dtype.base_dtype, + _scope=sc, + _reuse=reuse) + outputs = layer.apply(inputs) + + # Add variables to collections. 
+ _add_variable_to_collections(layer.kernel, variables_collections, 'weights') + if layer.use_bias: + _add_variable_to_collections(layer.bias, variables_collections, 'biases') + + if normalizer_fn is not None: + normalizer_params = normalizer_params or {} + outputs = normalizer_fn(outputs, **normalizer_params) + + if activation_fn is not None: + outputs = activation_fn(outputs) + return utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, outputs) + + +masked_conv2d = masked_convolution + + +@add_arg_scope +def masked_fully_connected( + inputs, + num_outputs, + activation_fn=nn.relu, + normalizer_fn=None, + normalizer_params=None, + weights_initializer=initializers.xavier_initializer(), + weights_regularizer=None, + biases_initializer=init_ops.zeros_initializer(), + biases_regularizer=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + scope=None): + """Adds a sparse fully connected layer. The weight matrix is masked. + + `fully_connected` creates a variable called `weights`, representing a fully + connected weight matrix, which is multiplied by the `inputs` to produce a + `Tensor` of hidden units. If a `normalizer_fn` is provided (such as + `batch_norm`), it is then applied. Otherwise, if `normalizer_fn` is + None and a `biases_initializer` is provided then a `biases` variable would be + created and added the hidden units. Finally, if `activation_fn` is not `None`, + it is applied to the hidden units as well. + + Note: that if `inputs` have a rank greater than 2, then `inputs` is flattened + prior to the initial matrix multiply by `weights`. + + Args: + inputs: A tensor of at least rank 2 and static value for the last dimension; + i.e. `[batch_size, depth]`, `[None, None, None, channels]`. + num_outputs: Integer or long, the number of output units in the layer. + activation_fn: Activation function. The default value is a ReLU function. 
+ Explicitly set it to None to skip it and maintain a linear activation. + normalizer_fn: Normalization function to use instead of `biases`. If + `normalizer_fn` is provided then `biases_initializer` and + `biases_regularizer` are ignored and `biases` are not created nor added. + default set to None for no normalizer function + normalizer_params: Normalization function parameters. + weights_initializer: An initializer for the weights. + weights_regularizer: Optional regularizer for the weights. + biases_initializer: An initializer for the biases. If None skip biases. + biases_regularizer: Optional regularizer for the biases. + reuse: Whether or not the layer and its variables should be reused. To be + able to reuse the layer scope must be given. + variables_collections: Optional list of collections for all the variables or + a dictionary containing a different list of collections per variable. + outputs_collections: Collection to add the outputs. + trainable: If `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). + scope: Optional scope for variable_scope. + + Returns: + The tensor variable representing the result of the series of operations. + + Raises: + ValueError: If x has rank less than 2 or if its last dimension is not set. + """ + if not isinstance(num_outputs, six.integer_types): + raise ValueError('num_outputs should be int or long, got %s.' 
% + (num_outputs,)) + + layer_variable_getter = _build_variable_getter({ + 'bias': 'biases', + 'kernel': 'weights' + }) + + with variable_scope.variable_scope( + scope, + 'fully_connected', [inputs], + reuse=reuse, + custom_getter=layer_variable_getter) as sc: + inputs = ops.convert_to_tensor(inputs) + layer = core.MaskedFullyConnected( + units=num_outputs, + activation=None, + use_bias=not normalizer_fn and biases_initializer, + kernel_initializer=weights_initializer, + bias_initializer=biases_initializer, + kernel_regularizer=weights_regularizer, + bias_regularizer=biases_regularizer, + activity_regularizer=None, + trainable=trainable, + name=sc.name, + dtype=inputs.dtype.base_dtype, + _scope=sc, + _reuse=reuse) + outputs = layer.apply(inputs) + + # Add variables to collections. + _add_variable_to_collections(layer.kernel, variables_collections, 'weights') + if layer.bias is not None: + _add_variable_to_collections(layer.bias, variables_collections, 'biases') + + # Apply normalizer function / layer. + if normalizer_fn is not None: + if not normalizer_params: + normalizer_params = {} + outputs = normalizer_fn(outputs, **normalizer_params) + + if activation_fn is not None: + outputs = activation_fn(outputs) + + return utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, outputs) diff --git a/tensorflow/contrib/model_pruning/python/layers/layers_test.py b/tensorflow/contrib/model_pruning/python/layers/layers_test.py new file mode 100644 index 0000000000..97a2c97850 --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/layers/layers_test.py @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for imagingvision.intelligence.tensorflow.model_pruning.layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.model_pruning.python.layers import core_layers +from tensorflow.contrib.model_pruning.python.layers import layers +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class MaskedConvolutionLayerTest(test.TestCase): + + def setUp(self): + super(MaskedConvolutionLayerTest, self).setUp() + self.height, self.width = 7, 9 + + def testInvalidRank3(self): + input_tensor = array_ops.ones((self.height, self.width, 3)) + with self.assertRaisesRegexp(ValueError, 'rank'): + layers.masked_conv2d(input_tensor, 32, 3) + + def testInvalidRank5(self): + input_tensor = array_ops.ones((8, 8, self.height, self.width, 3)) + with self.assertRaisesRegexp(ValueError, 'rank'): + layers.masked_conv2d(input_tensor, 32, 3) + + def testSingleConvMaskAdded(self): + kernel_size = 3 + input_depth, output_depth = 8, 32 + input_tensor = array_ops.ones((8, self.height, self.width, input_depth)) + layers.masked_conv2d(input_tensor, output_depth, kernel_size) + + masks = ops.get_collection(core_layers.MASK_COLLECTION) + self.assertEqual(len(masks), 1) + self.assertListEqual(masks[0].get_shape().as_list(), + [kernel_size, kernel_size, input_depth, output_depth]) + + masked_weight = 
ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION) + self.assertEqual(len(masked_weight), 1) + self.assertListEqual(masked_weight[0].get_shape().as_list(), + [kernel_size, kernel_size, input_depth, output_depth]) + + def testMultipleConvMaskAdded(self): + number_of_layers = 5 + + kernel_size = 3 + base_depth = 4 + depth_step = 7 + + input_tensor = array_ops.ones((8, self.height, self.width, base_depth)) + + top_layer = input_tensor + + for ix in range(number_of_layers): + top_layer = layers.masked_conv2d(top_layer, base_depth + + (ix + 1) * depth_step, kernel_size) + + masks = ops.get_collection(core_layers.MASK_COLLECTION) + self.assertEqual(len(masks), number_of_layers) + for ix in range(number_of_layers): + self.assertListEqual(masks[ix].get_shape().as_list(), [ + kernel_size, kernel_size, base_depth + ix * depth_step, + base_depth + (ix + 1) * depth_step + ]) + + masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION) + self.assertEqual(len(masked_weight), number_of_layers) + for ix in range(number_of_layers): + self.assertListEqual(masked_weight[ix].get_shape().as_list(), [ + kernel_size, kernel_size, base_depth + ix * depth_step, + base_depth + (ix + 1) * depth_step + ]) + + +class MaskedFullyConnectedLayerTest(test.TestCase): + + def testSingleFCMaskAdded(self): + input_depth, output_depth = 8, 32 + input_tensor = array_ops.ones((5, input_depth)) + layers.masked_fully_connected(input_tensor, output_depth) + + masks = ops.get_collection(core_layers.MASK_COLLECTION) + self.assertEqual(len(masks), 1) + self.assertListEqual(masks[0].get_shape().as_list(), + [input_depth, output_depth]) + + masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION) + self.assertEqual(len(masked_weight), 1) + self.assertListEqual(masked_weight[0].get_shape().as_list(), + [input_depth, output_depth]) + + def testMultipleConvMaskAdded(self): + number_of_layers = 5 + + base_depth = 4 + depth_step = 7 + + input_tensor = array_ops.ones((8, 
base_depth)) + + top_layer = input_tensor + + for ix in range(number_of_layers): + top_layer = layers.masked_fully_connected(top_layer, base_depth + + (ix + 1) * depth_step) + + masks = ops.get_collection(core_layers.MASK_COLLECTION) + self.assertEqual(len(masks), number_of_layers) + for ix in range(number_of_layers): + self.assertListEqual(masks[ix].get_shape().as_list(), [ + base_depth + ix * depth_step, base_depth + (ix + 1) * depth_step + ]) + + masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION) + self.assertEqual(len(masked_weight), number_of_layers) + for ix in range(number_of_layers): + self.assertListEqual(masked_weight[ix].get_shape().as_list(), [ + base_depth + ix * depth_step, base_depth + (ix + 1) * depth_step + ]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py new file mode 100644 index 0000000000..18ba3d1327 --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py @@ -0,0 +1,340 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module implementing RNN Cells with pruning. + +This module implements BasicLSTMCell and LSTMCell with pruning. 
+Code adapted from third_party/tensorflow/python/ops/rnn_cell_impl.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.model_pruning.python.layers import core_layers +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import rnn_cell as tf_rnn + + +class MaskedBasicLSTMCell(tf_rnn.BasicLSTMCell): + """Basic LSTM recurrent network cell with pruning. + + Overrides the call method of tensorflow BasicLSTMCell and injects the weight + masks + + The implementation is based on: http://arxiv.org/abs/1409.2329. + + We add forget_bias (default: 1) to the biases of the forget gate in order to + reduce the scale of forgetting in the beginning of the training. + + It does not allow cell clipping, a projection layer, and does not + use peep-hole connections: it is the basic baseline. + + For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell} + that follows. + """ + + def __init__(self, + num_units, + forget_bias=1.0, + state_is_tuple=True, + activation=None, + reuse=None, + name=None): + """Initialize the basic LSTM cell with pruning. + + Args: + num_units: int, The number of units in the LSTM cell. + forget_bias: float, The bias added to forget gates (see above). + Must set to `0.0` manually when restoring from CudnnLSTM-trained + checkpoints. + state_is_tuple: If True, accepted and returned states are 2-tuples of + the `c_state` and `m_state`. If False, they are concatenated + along the column axis. The latter behavior will soon be deprecated. + activation: Activation function of the inner states. Default: `tanh`. 
+ reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + name: String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require reuse=True in such + cases. + + When restoring from CudnnLSTM-trained checkpoints, must use + CudnnCompatibleLSTMCell instead. + """ + super(MaskedBasicLSTMCell, self).__init__( + num_units, + forget_bias=forget_bias, + state_is_tuple=state_is_tuple, + activation=activation, + reuse=reuse, + name=name) + + def build(self, inputs_shape): + # Call the build method of the parent class. + super(MaskedBasicLSTMCell, self).build(inputs_shape) + + input_depth = inputs_shape[1].value + h_depth = self._num_units + self._mask = self.add_variable( + name="mask", + shape=[input_depth + h_depth, 4 * h_depth], + initializer=init_ops.ones_initializer(), + trainable=False, + dtype=self.dtype) + self._threshold = self.add_variable( + name="threshold", + shape=[], + initializer=init_ops.zeros_initializer(), + trainable=False, + dtype=self.dtype) + # Add masked_weights in the weights namescope so as to make it easier + # for the quantization library to add quant ops. + self._masked_kernel = math_ops.multiply(self._mask, self._kernel, + core_layers.MASKED_WEIGHT_NAME) + if self._mask not in ops.get_collection_ref(core_layers.MASK_COLLECTION): + ops.add_to_collection(core_layers.MASK_COLLECTION, self._mask) + ops.add_to_collection(core_layers.MASKED_WEIGHT_COLLECTION, + self._masked_kernel) + ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold) + ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel) + + def call(self, inputs, state): + """Long short-term memory cell (LSTM) with masks for pruning. + + Args: + inputs: `2-D` tensor with shape `[batch_size, input_size]`. 
class MaskedLSTMCell(tf_rnn.LSTMCell):
  """LSTMCell with pruning.

  Overrides the call method of tensorflow LSTMCell and injects the weight masks.
  Masks are applied to only the weight matrix of the LSTM and not the
  projection matrix.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               cell_clip=None,
               initializer=None,
               num_proj=None,
               proj_clip=None,
               num_unit_shards=None,
               num_proj_shards=None,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=None,
               reuse=None):
    """Initialize the parameters for an LSTM cell with masks for pruning.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      cell_clip: (optional) A float value, if provided the cell state is clipped
        by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
        provided, then the projected values are clipped elementwise to within
        `[-proj_clip, proj_clip]`.
      num_unit_shards: Deprecated, will be removed by Jan. 2017.
        Use a variable_scope partitioner instead.
      num_proj_shards: Deprecated, will be removed by Jan. 2017.
        Use a variable_scope partitioner instead.
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training. Must set it manually to `0.0` when restoring from
        CudnnLSTM trained checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  If False, they are concatenated
        along the column axis.  This latter behavior will soon be deprecated.
      activation: Activation function of the inner states.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.

      When restoring from CudnnLSTM-trained checkpoints, must use
      CudnnCompatibleLSTMCell instead.
    """
    super(MaskedLSTMCell, self).__init__(
        num_units,
        use_peepholes=use_peepholes,
        cell_clip=cell_clip,
        initializer=initializer,
        num_proj=num_proj,
        proj_clip=proj_clip,
        num_unit_shards=num_unit_shards,
        num_proj_shards=num_proj_shards,
        forget_bias=forget_bias,
        state_is_tuple=state_is_tuple,
        activation=activation,
        reuse=reuse)

  def build(self, inputs_shape):
    """Creates the parent cell's variables, then the pruning mask and threshold.

    Args:
      inputs_shape: Static shape of the cell inputs; dimension 1 is read as the
        input depth.
    """
    # Call the build method of the parent class.
    super(MaskedLSTMCell, self).build(inputs_shape)

    input_depth = inputs_shape[1].value
    num_units = self._num_units
    # `call` multiplies concat([inputs, m_prev]) with the masked kernel, and
    # m_prev is `num_proj` wide when a projection is configured. The mask must
    # therefore use the projected width for its rows; using `num_units` there
    # only matches the kernel when no projection is used.
    h_depth = num_units if self._num_proj is None else self._num_proj
    self._mask = self.add_variable(
        name="mask",
        shape=[input_depth + h_depth, 4 * num_units],
        initializer=init_ops.ones_initializer(),
        trainable=False,
        dtype=self.dtype)
    self._threshold = self.add_variable(
        name="threshold",
        shape=[],
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        dtype=self.dtype)
    # Add masked_weights in the weights namescope so as to make it easier
    # for the quantization library to add quant ops.
    self._masked_kernel = math_ops.multiply(self._mask, self._kernel,
                                            core_layers.MASKED_WEIGHT_NAME)
    # Register the mask and its companions with the pruning collections only
    # once per mask variable.
    if self._mask not in ops.get_collection_ref(core_layers.MASK_COLLECTION):
      ops.add_to_collection(core_layers.MASK_COLLECTION, self._mask)
      ops.add_to_collection(core_layers.MASKED_WEIGHT_COLLECTION,
                            self._masked_kernel)
      ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold)
      ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel)

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, `[batch, num_units]`.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, [batch, state_size]`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch, output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    num_proj = self._num_units if self._num_proj is None else self._num_proj
    sigmoid = math_ops.sigmoid

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      # Concatenated state layout: first num_units columns are c, the
      # following num_proj columns are m.
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    # The masked kernel is used here, so pruned weights contribute nothing.
    lstm_matrix = math_ops.matmul(
        array_ops.concat([inputs, m_prev], 1), self._masked_kernel)
    lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias)

    i, j, f, o = array_ops.split(
        value=lstm_matrix, num_or_size_splits=4, axis=1)
    # Diagonal connections
    if self._use_peepholes:
      c = (
          sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
          sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
    else:
      c = (
          sigmoid(f + self._forget_bias) * c_prev +
          sigmoid(i) * self._activation(j))

    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
      # pylint: enable=invalid-unary-operand-type
    if self._use_peepholes:
      m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      # The projection matrix is applied unmasked.
      m = math_ops.matmul(m, self._proj_kernel)

      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    new_state = (
        tf_rnn.LSTMStateTuple(c, m)
        if self._state_is_tuple else array_ops.concat([c, m], 1))
    return m, new_state
class RnnCellsTest(test.TestCase):
  """Checks that each masked LSTM cell registers exactly one mask set."""

  def setUp(self):
    super(RnnCellsTest, self).setUp()
    self.batch_size = 8
    self.dim = 10

  def _random_variable(self):
    """Returns a `[batch_size, dim]` variable with random normal values."""
    return variables.Variable(
        random_ops.random_normal([self.batch_size, self.dim]))

  def _check_single_mask(self, cell_class):
    """Runs one step of `cell_class` and checks the pruning collections.

    Args:
      cell_class: The masked cell class to instantiate with `self.dim` units.
    """
    expected_num_masks = 1
    expected_shape = (2 * self.dim, 4 * self.dim)
    with self.test_session():
      inputs = self._random_variable()
      c = self._random_variable()
      h = self._random_variable()
      lstm_cell = cell_class(self.dim)
      lstm_cell(inputs, tf_rnn_cells.LSTMStateTuple(c, h))

      # One mask, masked weight, threshold and weight per cell.
      for getter in (pruning.get_masks, pruning.get_masked_weights,
                     pruning.get_thresholds, pruning.get_weights):
        self.assertEqual(len(getter()), expected_num_masks)

      for mask in pruning.get_masks():
        self.assertEqual(mask.shape, expected_shape)
      for weight in pruning.get_weights():
        self.assertEqual(weight.shape, expected_shape)

  def testMaskedBasicLSTMCell(self):
    self._check_single_mask(rnn_cells.MaskedBasicLSTMCell)

  def testMaskedLSTMCell(self):
    self._check_single_mask(rnn_cells.MaskedLSTMCell)
def train(train_op,
          logdir,
          mask_update_op,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_steps=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=_USE_DEFAULT,
          init_fn=None,
          ready_op=_USE_DEFAULT,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          summary_writer=_USE_DEFAULT,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None,
          session_config=None,
          trace_every_n_steps=None):
  """Wrapper around tf-slim's train function.

  Runs a training loop using a TensorFlow supervisor. After every training
  step, `mask_update_op` is run in the same session so that the pruning masks
  and thresholds are kept up to date.
  When the sync_optimizer is supplied, gradient updates are applied
  synchronously. Otherwise, gradient updates are applied asynchronously.

  Args:
    train_op: A `Tensor` that, when executed, will apply the gradients and
      return the loss value.
    logdir: The directory where training logs are written to. If None, model
      checkpoints and summaries will not be written.
    mask_update_op: Operation that upon execution updates the weight masks and
      thresholds.
    train_step_fn: The function to call in order to execute a single gradient
      step. The function must take exactly four arguments: the current
      session, the `train_op` `Tensor`, a global step `Tensor` and a
      dictionary. It must return a `(total_loss, should_stop)` pair.
    train_step_kwargs: A dictionary which is passed to the `train_step_fn`. By
      default, two `Boolean`, scalar ops called "should_stop" and "should_log"
      are provided.
    log_every_n_steps: The frequency, in terms of global steps, that the loss
      and global step are logged.
    graph: The graph to pass to the supervisor. If no graph is supplied the
      default graph is used.
    master: The address of the tensorflow master.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    global_step: The `Tensor` representing the global step. If left as `None`,
      then slim.variables.get_or_create_global_step() is used.
    number_of_steps: The max number of gradient steps to take during training,
      as measured by 'global_step': training will stop if global_step is
      greater than 'number_of_steps'. If the value is left as None, training
      proceeds indefinitely.
    init_op: The initialization operation. If left to its default value, then
      the session is initialized by calling `tf.global_variables_initializer()`.
    init_feed_dict: A feed dictionary to use when executing the `init_op`.
    local_init_op: The local initialization operation. If left to its default
      value, then the session is initialized by calling
      `tf.local_variables_initializer()` and `tf.tables_initializer()`.
    init_fn: An optional callable to be executed after `init_op` is called. The
      callable must accept one argument, the session being initialized.
    ready_op: Operation to check if the model is ready to use. If left to its
      default value, then the session checks for readiness by calling
      `tf.report_uninitialized_variables()`.
    summary_op: The summary operation.
    save_summaries_secs: How often, in seconds, to save summaries.
    summary_writer: `SummaryWriter` to use.  Can be `None`
      to indicate that no summaries should be written. If unset, we
      create a SummaryWriter.
    startup_delay_steps: The number of steps to wait for before beginning. Note
      that this must be 0 if a sync_optimizer is supplied.
    saver: Saver to save checkpoints. If None, a default one will be created
      and used.
    save_interval_secs: How often, in seconds, to save the model to `logdir`.
    sync_optimizer: an instance of tf.train.SyncReplicasOptimizer, or a list of
      them. If the argument is supplied, gradient updates will be synchronous.
      If left as `None`, gradient updates will be asynchronous.
    session_config: An instance of `tf.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.
    trace_every_n_steps: produce and save a `Timeline` in Chrome trace format
      and add it to the summaries every `trace_every_n_steps`. If None, no trace
      information will be produced or saved.

  Returns:
    the value of the loss function after training.

  Raises:
    ValueError: if `train_op` is empty or if `startup_delay_steps` is
      non-zero when `sync_optimizer` is supplied, if `number_of_steps` is
      negative, or if `trace_every_n_steps` is not `None` and no `logdir` is
      provided.
  """

  def train_step_with_pruning_fn(sess, train_op, global_step,
                                 train_step_kwargs):
    """Runs the wrapped train step, then refreshes the pruning masks."""
    total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                            train_step_kwargs)
    sess.run(mask_update_op)
    return total_loss, should_stop

  # slim's train() returns the final loss value itself, not a
  # (loss, should_stop) pair, so it must not be tuple-unpacked here.
  return _slim.learning.train(
      train_op,
      logdir,
      train_step_fn=train_step_with_pruning_fn,
      train_step_kwargs=train_step_kwargs,
      log_every_n_steps=log_every_n_steps,
      graph=graph,
      master=master,
      is_chief=is_chief,
      global_step=global_step,
      number_of_steps=number_of_steps,
      init_op=init_op,
      init_feed_dict=init_feed_dict,
      local_init_op=local_init_op,
      init_fn=init_fn,
      ready_op=ready_op,
      summary_op=summary_op,
      save_summaries_secs=save_summaries_secs,
      summary_writer=summary_writer,
      startup_delay_steps=startup_delay_steps,
      saver=saver,
      save_interval_secs=save_interval_secs,
      sync_optimizer=sync_optimizer,
      session_config=session_config,
      trace_every_n_steps=trace_every_n_steps)
+# ============================================================================== +"""Helper functions to add support for magnitude-based model pruning. + + # Adds variables and ops to the graph to enable + # elementwise masking of weights + apply_mask(weights) + + # Returns a list containing the sparsity of each of the weight tensors + get_weight_sparsity() + + # Returns a list of all the masked weight tensorflow variables + get_masked_weights() + + # Returns a list of all the mask tensorflow variables + get_masks() + + # Returns a list of all the thresholds + get_thresholds() + + # Returns a list of all the weight tensors that have been masked + get_weights() + + The Pruning class uses a proto (defined in pruning.proto) to set up the + parameters for a pruning specification. Here's a typical usage: + + # Initialize a pruning spec from a proto + pruning_spec = '/tmp/pruning.pb' + p = Pruning(pruning_spec) + + # Add mask update ops to the graph + mask_update_op = p.conditional_mask_update_op() + + # Add the summaries + p.add_pruning_summaries() + + # Run the op + session.run(mask_update_op) + + # An object of the pruning also accepts externally defined sparsity: + sparsity = tf.Variable(0.5, name = "ConstantSparsity") + pruning_spec = '/tmp/pruning.pb' + p = Pruning(pruning_spec, sparsity=sparsity) + +""" +# pylint: disable=missing-docstring +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.model_pruning.python.layers import core_layers as core +from tensorflow.contrib.training.python.training import hparam +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops 
def _weight_mask_variable(var, scope):
  """Creates (or fetches) the 'mask' variable for a weight tensor.

  The variable is named 'mask', lives under `scope`, and is not trainable.

  Args:
    var: The weight variable that needs to be masked.
    scope: The variable scope of the variable var.

  Returns:
    The mask variable, with the same shape and dtype as var, initialized to
    all 1s.
  """
  with variable_scope.variable_scope(scope):
    return variable_scope.get_variable(
        name='mask',
        shape=var.get_shape(),
        dtype=var.dtype,
        initializer=init_ops.ones_initializer(),
        trainable=False)
+ value_range: Shape [2] `Tensor` of same `dtype` as `values`. + values <= value_range[0] will be mapped to hist[0], + values >= value_range[1] will be mapped to hist[-1]. + nbins: Scalar `int32 Tensor`. Number of histogram bins. + dtype: dtype for returned histogram. + name: A name for this operation (defaults to 'histogram'). + + Returns: + A 1-D `Tensor` holding histogram of values. + + """ + with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope: + values = ops.convert_to_tensor(values, name='values') + values = gen_array_ops.reshape(values, [-1]) + value_range = ops.convert_to_tensor(value_range, name='value_range') + nbins = ops.convert_to_tensor(nbins, dtype=np.int32, name='nbins') + nbins_float = math_ops.cast(nbins, values.dtype) + + # Map tensor values that fall within value_range to [0, 1]. + scaled_values = math_ops.truediv( + values - value_range[0], + value_range[1] - value_range[0], + name='scaled_values') + + # map tensor values within the open interval value_range to {0,.., nbins-1}, + # values outside the open interval will be zero or less, or nbins or more. + indices = math_ops.floor(nbins_float * scaled_values, name='indices') + + # Clip edge cases (e.g. value = value_range[1]) or "outliers." 
def _determine_partitioned_axis(partitioned_variable):
  """Finds the axis along which a partitioned variable was split.

  Each partition's shape is compared with the shape of the whole variable; the
  partitioned axis is the single axis on which the partition is smaller.

  Args:
    partitioned_variable: An iterable of partitions that also exposes
      `get_shape()` for the concatenated variable.

  Returns:
    The index of the partitioned axis as determined from the last partition,
    or 0 if there are no partitions to inspect.

  Raises:
    ValueError: If a partition is smaller than the whole variable along zero
      axes or along more than one axis.
  """
  full_shape = partitioned_variable.get_shape()
  axis = 0
  for shard in partitioned_variable:
    is_smaller = np.less(shard.get_shape(), full_shape)
    num_smaller = np.count_nonzero(is_smaller)
    # Sanity check: make sure number of partitioned axis == 1
    if num_smaller != 1:
      raise ValueError(
          'Number of partitioned axes %s not equal to 1' % num_smaller)
    axis = np.where(is_smaller)[0][0]
  return axis
def apply_mask(x, scope=''):
  """Apply mask to a given weight tensor.

  Creates (or fetches) the mask and threshold variables for `x` under `scope`,
  multiplies `x` by the mask, and registers the mask, threshold, masked
  weights and weights in the pruning collections.

  Args:
    x: Input weight tensor
    scope: The current variable scope. Defaults to ""
  Returns:
    Tensor representing masked_weights
  """
  weight_mask = _weight_mask_variable(x, scope)
  weight_threshold = _weight_threshold_variable(x, scope)
  # Add masked_weights in the weights namescope so as to make it easier
  # for the quantization library to add quant ops.
  masked = math_ops.multiply(weight_mask, x, _MASKED_WEIGHT_NAME)

  # Make sure the mask for a given variable are not added multiple times to the
  # collection. This is particularly important when applying mask to RNN's
  # weight variables
  already_registered = weight_mask in ops.get_collection_ref(_MASK_COLLECTION)
  if not already_registered:
    for collection_name, member in ((_THRESHOLD_COLLECTION, weight_threshold),
                                    (_MASK_COLLECTION, weight_mask),
                                    (_MASKED_WEIGHT_COLLECTION, masked),
                                    (_WEIGHT_COLLECTION, x)):
      ops.add_to_collection(collection_name, member)
  return masked
+ + Args: + None + + Returns: + A list containing the sparsity of each of the weight tensors + """ + masks = get_masks() + return [nn_impl.zero_fraction(mask) for mask in masks] + + +def get_pruning_hparams(): + """Get a tf.HParams object with the default values for the hyperparameters. + + name: string + name of the pruning specification. Used for adding summaries and ops under + a common tensorflow name_scope + begin_pruning_step: integer + the global step at which to begin pruning + end_pruning_step: integer + the global step at which to terminate pruning. Defaults to -1 implying + that pruning continues till the training stops + do_not_prune: list of strings + list of layers that are not pruned + threshold_decay: float + the decay factor to use for exponential decay of the thresholds + pruning_frequency: integer + How often should the masks be updated? (in # of global_steps) + nbins: integer + number of bins to use for histogram computation + initial_sparsity: float + initial sparsity value + target_sparsity: float + target sparsity value + sparsity_function_begin_step: integer + the global step at this which the gradual sparsity function begins to + take effect + sparsity_function_end_step: integer + the global step used as the end point for the gradual sparsity function + sparsity_function_exponent: float + exponent = 1 is linearly varying sparsity between initial and final. 
def __init__(self,
             spec=None,
             global_step=None,
             sparsity=None,
             partitioner=None):
  """Set up the specification for model pruning.

  If a spec is provided, the sparsity is set up based on the sparsity_function
  in the spec. The effect of sparsity_function is overridden if the sparsity
  variable is passed to the constructor. This enables setting up arbitrary
  sparsity profiles externally and passing it to this pruning functions.

  Args:
    spec: Pruning spec as defined in pruning.proto. If not given, the default
      hyperparameters from get_pruning_hparams() are used.
    global_step: A tensorflow variable that is used while setting up the
      sparsity function
    sparsity: A tensorflow scalar variable storing the sparsity. If `None`,
      the sparsity is built from the spec's sparsity function.
    partitioner: The tensorflow partitioner function used to distribute
      parameters across shards
  """
  # Pruning specification
  self._spec = spec if spec else get_pruning_hparams()

  # A tensorflow variable that tracks the sparsity function.
  # If not provided as input, the graph must already contain the global_step
  # variable before calling this constructor.
  self._global_step = self._setup_global_step(global_step)

  # Stores the tensorflow sparsity variable.
  # Built using self._setup_sparsity() or provided externally.
  # Compare against None explicitly: evaluating the truthiness of a tensor
  # is not allowed, and an explicit sparsity of 0 must not be replaced by the
  # spec's schedule.
  self._sparsity = (
      sparsity if sparsity is not None else self._setup_sparsity())

  # Stores the partitioner function used to partition variables across tasks
  self._partitioner = partitioner

  # List of tensorflow assignments ops for new masks and thresholds
  self._assign_ops = []

  # Tensorflow variable keeping track of the last global step when the masks
  # were updated
  self._last_update_step = self._setup_last_update_step()
def _exists_in_do_not_prune_list(self, tensor_name):
  """Returns whether `tensor_name` matches an entry of the do-not-prune list.

  A tensor matches when any entry of `self._spec.do_not_prune` occurs as a
  substring of its name. A list that is empty, or whose first entry is empty
  (the default `['']`), means that nothing is excluded from pruning.

  Args:
    tensor_name: Name of the tensor (e.g. a mask variable) to look up.

  Returns:
    True if the tensor must not be pruned, False otherwise.
  """
  do_not_prune_list = self._spec.do_not_prune
  # Guard the empty list as well, so indexing [0] cannot raise IndexError.
  if not do_not_prune_list or not do_not_prune_list[0]:
    return False
  return any(layer_name in tensor_name for layer_name in do_not_prune_list)
The function will compute a new + threshold and return the exponential moving average using the current + value of threshold + + Returns: + new_threshold: The new value of the threshold based on weights, and + desired_sparsity + new_mask: A n-D numpy array containing 0 or 1 to indicate which of the + values in weights falls below the threshold + + Raises: + ValueError: if sparsity is not defined + """ + if self._sparsity is None: + raise ValueError('Sparsity variable undefined') + + with ops.name_scope(weights.op.name + '_pruning_ops'): + abs_weights = math_ops.abs(weights) + max_value = math_ops.reduce_max(abs_weights) + histogram = _histogram( + abs_weights, [0.0, max_value], + nbins=self._spec.nbins, + dtype=np.float32) + + cdf = math_ops.cumsum(histogram) + norm_cdf = math_ops.div(cdf, math_ops.reduce_sum(histogram)) + current_threshold = math_ops.multiply( + math_ops.div( + math_ops.reduce_sum( + math_ops.cast( + math_ops.less(norm_cdf, self._sparsity), np.float32)), + float(self._spec.nbins)), max_value) + + smoothed_threshold = math_ops.add_n([ + math_ops.multiply(current_threshold, 1 - self._spec.threshold_decay), + math_ops.multiply(threshold, self._spec.threshold_decay) + ]) + new_mask = math_ops.cast( + math_ops.greater(abs_weights, smoothed_threshold), np.float32) + return smoothed_threshold, new_mask + + def _get_mask_assign_ops(self): + # Make sure the assignment ops have not already been added to the list + if self._assign_ops: + raise ValueError( + 'Assign op list not empty. 
_get_mask_assign_ops() called twice?') + + masks = get_masks() + weights = get_weights() + thresholds = get_thresholds() + + if len(masks) != len(thresholds): + raise ValueError( + 'Number of masks %s and number of thresholds %s mismatch' % + (len(masks), len(thresholds))) + + for index, mask in enumerate(masks): + threshold = thresholds[index] + weight = weights[index] if self._partitioner is None else weights[ + index].as_tensor() + + if self._spec.do_not_prune: + if self._exists_in_do_not_prune_list(mask.name): + continue + + new_threshold, new_mask = self._update_mask(weight, threshold) + self._assign_ops.append(_variable_assign(threshold, new_threshold)) + self._assign_ops.append( + _variable_assign(mask, new_mask) if self._partitioner is None else + _partitioned_variable_assign(mask, new_mask)) + + def mask_update_op(self): + with ops.name_scope(self._spec.name): + if not self._assign_ops: + self._get_mask_assign_ops() + with ops.control_dependencies([ + state_ops.assign( + self._last_update_step, + self._global_step, + name='last_mask_update_step_assign') + ]): + with ops.control_dependencies(self._assign_ops): + logging.info('Updating masks.') + return control_flow_ops.no_op('mask_update') + + def conditional_mask_update_op(self): + + def maybe_update_masks(): + with ops.name_scope(self._spec.name): + is_step_within_pruning_range = math_ops.logical_and( + math_ops.greater_equal(self._global_step, + self._spec.begin_pruning_step), + # If end_pruning_step is negative, keep pruning forever! 
+ math_ops.logical_or( + math_ops.less_equal(self._global_step, + self._spec.end_pruning_step), + math_ops.less(self._spec.end_pruning_step, 0))) + is_pruning_step = math_ops.less_equal( + math_ops.add(self._last_update_step, self._spec.pruning_frequency), + self._global_step) + return math_ops.logical_and(is_step_within_pruning_range, + is_pruning_step) + + def mask_update_op(): + return self.mask_update_op() + + def no_update_op(): + return control_flow_ops.no_op() + + return control_flow_ops.cond(maybe_update_masks(), mask_update_op, + no_update_op) + + def add_pruning_summaries(self): + """Adds summaries for this pruning spec. + + Args: none + + Returns: none + """ + with ops.name_scope(self._spec.name + '_summaries'): + summary.scalar('sparsity', self._sparsity) + summary.scalar('last_mask_update_step', self._last_update_step) + masks = get_masks() + thresholds = get_thresholds() + for index, mask in enumerate(masks): + if not self._exists_in_do_not_prune_list(mask.name): + summary.scalar(mask.name + '/sparsity', nn_impl.zero_fraction(mask)) + summary.scalar(thresholds[index].op.name + '/threshold', + thresholds[index]) + + def print_hparams(self): + logging.info(self._spec.to_json()) diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py new file mode 100644 index 0000000000..c23fd649ce --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/pruning_test.py @@ -0,0 +1,162 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the key functions in pruning library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training_util


class PruningHParamsTest(test.TestCase):
  """Covers hparam parsing and construction of `pruning.Pruning` objects."""

  PARAM_LIST = [
      "name=test", "threshold_decay=0.9", "pruning_frequency=10",
      "do_not_prune=[conv1,conv2]", "sparsity_function_end_step=100",
      "target_sparsity=0.9"
  ]
  TEST_HPARAMS = ",".join(PARAM_LIST)

  def setUp(self):
    """Adds a global step and a sparsity variable, then parses the hparams."""
    super(PruningHParamsTest, self).setUp()
    # Global step variable for the graph.
    self.global_step = training_util.get_or_create_global_step()
    # Sparsity supplied from outside the Pruning object.
    self.sparsity = variables.Variable(0.5, name="sparsity")
    # Hparams parsed from the comma-separated override string.
    default_hparams = pruning.get_pruning_hparams()
    self.pruning_hparams = default_hparams.parse(self.TEST_HPARAMS)

  def testInit(self):
    """Every override in TEST_HPARAMS must show up on the object's spec."""
    spec = pruning.Pruning(self.pruning_hparams)._spec
    self.assertEqual(spec.name, "test")
    self.assertAlmostEqual(spec.threshold_decay, 0.9)
    self.assertEqual(spec.pruning_frequency, 10)
    self.assertAllEqual(spec.do_not_prune, ["conv1", "conv2"])
    self.assertEqual(spec.sparsity_function_end_step, 100)
    self.assertAlmostEqual(spec.target_sparsity, 0.9)

  def testInitWithExternalSparsity(self):
    """A sparsity passed to the constructor is the one the object uses."""
    with self.test_session():
      pruner = pruning.Pruning(
          spec=self.pruning_hparams, sparsity=self.sparsity)
      variables.global_variables_initializer().run()
      observed = pruner._sparsity.eval()
      self.assertAlmostEqual(observed, 0.5)

  def testInitWithVariableReuse(self):
    """Two objects built from the same spec can coexist in one graph."""
    with self.test_session():
      first = pruning.Pruning(
          spec=self.pruning_hparams, sparsity=self.sparsity)
      second = pruning.Pruning(
          spec=self.pruning_hparams, sparsity=self.sparsity)
      variables.global_variables_initializer().run()
      observed = first._sparsity.eval()
      self.assertAlmostEqual(observed, 0.5)
      self.assertEqual(first._sparsity.eval(), second._sparsity.eval())


class PruningTest(test.TestCase):
  """Covers mask creation and (conditional) mask updates."""

  def setUp(self):
    """Makes sure the graph has a global step."""
    super(PruningTest, self).setUp()
    self.global_step = training_util.get_or_create_global_step()

  def testCreateMask2D(self):
    """Before any update, the masked weights equal the raw weights."""
    shape = [10, 20]  # [width, height]
    with self.test_session():
      initial_value = random_ops.random_normal(shape, stddev=1)
      weights = variables.Variable(initial_value, name="weights")
      current_scope = variable_scope.get_variable_scope()
      masked_weights = pruning.apply_mask(weights, current_scope)
      variables.global_variables_initializer().run()
      raw = weights.eval()
      masked = masked_weights.eval()
      self.assertAllEqual(raw, masked)

  def testUpdateSingleMask(self):
    """One update at sparsity 0.5 leaves 51 of 100 weights non-zero."""
    with self.test_session() as session:
      weights = variables.Variable(
          math_ops.linspace(1.0, 100.0, 100), name="weights")
      masked_weights = pruning.apply_mask(weights)
      sparsity = variables.Variable(0.5, name="sparsity")
      pruner = pruning.Pruning(sparsity=sparsity)
      # No smoothing: the freshly estimated threshold is used as is.
      pruner._spec.threshold_decay = 0.0
      update_op = pruner.mask_update_op()
      variables.global_variables_initializer().run()
      before = masked_weights.eval()
      self.assertAllEqual(np.count_nonzero(before), 100)
      session.run(update_op)
      after = masked_weights.eval()
      self.assertAllEqual(np.count_nonzero(after), 51)

  def testPartitionedVariableMasking(self):
    """Same as the single-mask case, with the weights partitioned."""
    partitioner = partitioned_variables.variable_axis_size_partitioner(40)
    with self.test_session() as session:
      with variable_scope.variable_scope("", partitioner=partitioner):
        sparsity = variables.Variable(0.5, name="Sparsity")
        weights = variable_scope.get_variable(
            "weights", initializer=math_ops.linspace(1.0, 100.0, 100))
        masked_weights = pruning.apply_mask(
            weights, scope=variable_scope.get_variable_scope())
      pruner = pruning.Pruning(sparsity=sparsity, partitioner=partitioner)
      pruner._spec.threshold_decay = 0.0
      update_op = pruner.mask_update_op()
      variables.global_variables_initializer().run()
      # Evaluated once before the update as well; only the post-update
      # value is asserted on.
      masked = masked_weights.eval()
      session.run(update_op)
      masked = masked_weights.eval()
      self.assertAllEqual(np.count_nonzero(masked), 51)

  def testConditionalMaskUpdate(self):
    """The conditional op only touches the masks on scheduled steps."""
    overrides = ",".join(
        ["pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6"])
    pruning_hparams = pruning.get_pruning_hparams().parse(overrides)
    weights = variables.Variable(
        math_ops.linspace(1.0, 100.0, 100), name="weights")
    masked_weights = pruning.apply_mask(weights)
    sparsity = variables.Variable(0.00, name="sparsity")
    # Set up pruning
    pruner = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    pruner._spec.threshold_decay = 0.0
    update_op = pruner.conditional_mask_update_op()
    sparsity_schedule = math_ops.linspace(0.0, 0.9, 10)
    advance_step = state_ops.assign_add(self.global_step, 1)
    observed_counts = []
    with self.test_session() as session:
      variables.global_variables_initializer().run()
      for step in range(10):
        session.run(state_ops.assign(sparsity, sparsity_schedule[step]))
        session.run(update_op)
        session.run(advance_step)
        observed_counts.append(np.count_nonzero(masked_weights.eval()))
    # The expected counts below change only after global steps 2, 4 and 6
    # (sparsity 0.2, 0.4 and 0.6 at those steps).
    expected_counts = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
    self.assertAllEqual(expected_counts, observed_counts)


if __name__ == "__main__":
  test.main()