aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-01 11:55:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-01 11:58:50 -0700
commit7ece1c0b8e527d59d8082cd6428cd255e5700074 (patch)
tree5af28c04f411ec5a7b3cd8a74172459d59806521 /tensorflow/contrib/model_pruning
parent693325c83255f1ec95744f3b92da3b1b075b1259 (diff)
Moving model_pruning library to tf.contrib
PiperOrigin-RevId: 174214419
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r--tensorflow/contrib/model_pruning/BUILD139
-rw-r--r--tensorflow/contrib/model_pruning/README.md195
-rw-r--r--tensorflow/contrib/model_pruning/__init__.py46
-rw-r--r--tensorflow/contrib/model_pruning/examples/cifar10/BUILD77
-rw-r--r--tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval.py178
-rw-r--r--tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py256
-rw-r--r--tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py395
-rw-r--r--tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py159
-rw-r--r--tensorflow/contrib/model_pruning/python/layers/core_layers.py477
-rw-r--r--tensorflow/contrib/model_pruning/python/layers/layers.py364
-rw-r--r--tensorflow/contrib/model_pruning/python/layers/layers_test.py139
-rw-r--r--tensorflow/contrib/model_pruning/python/layers/rnn_cells.py340
-rw-r--r--tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py85
-rw-r--r--tensorflow/contrib/model_pruning/python/learning.py188
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning.py585
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_test.py162
16 files changed, 3785 insertions, 0 deletions
diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD
new file mode 100644
index 0000000000..ca3f13479e
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/BUILD
@@ -0,0 +1,139 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+ name = "core_layers",
+ srcs = ["python/layers/core_layers.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/python:layers",
+ "//tensorflow/python:ops",
+ "//tensorflow/python:platform",
+ ],
+)
+
+py_library(
+ name = "layers",
+ srcs = ["python/layers/layers.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core_layers",
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/layers:layers_py",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "layers_test",
+ size = "small",
+ srcs = ["python/layers/layers_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":layers",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_library(
+ name = "learning",
+ srcs = ["python/learning.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/slim",
+ ],
+)
+
+py_library(
+ name = "rnn_cells",
+ srcs = ["python/layers/rnn_cells.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core_layers",
+ ],
+)
+
+py_library(
+ name = "pruning",
+ srcs = ["python/pruning.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":core_layers",
+ "//tensorflow/contrib/training:training_py",
+ "//tensorflow/python:platform",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "pruning_test",
+ size = "small",
+ srcs = ["python/pruning_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pruning",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "rnn_cells_test",
+ size = "small",
+ srcs = ["python/layers/rnn_cells_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pruning",
+ ":rnn_cells",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_library(
+ name = "init_py",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+)
+
+# Top-level library
+py_library(
+ name = "model_pruning",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":init_py",
+ ":layers",
+ ":learning",
+ ":pruning",
+ ":rnn_cells",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
new file mode 100644
index 0000000000..a8427e6014
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -0,0 +1,195 @@
+# Model pruning: Training tensorflow models to have masked connections
+
+This document describes the API that facilitates magnitude-based pruning of
+neural network's weight tensors. The API helps inject necessary tensorflow op
+into the training graph so the model can be pruned while it is being trained.
+
+### Model creation
+
+The first step involves adding mask and threshold variables to the layers that
+need to undergo pruning. The variable mask is the same shape as the layer's
+weight tensor and determines which of the weights participate in the forward
+execution of the graph. This can be achieved by wrapping the weight tensor of
+the layer with the `apply_mask` function provided in
+[pruning.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/pruning.py).
+For example:
+
+```python
+conv = tf.nn.conv2d(images, pruning.apply_mask(weights), stride, padding)
+```
+
+This creates a convolutional layer with additional variables mask and threshold
+as shown below: ![Convolutional layer with mask and
+threshold](./mask.png "Convolutional layer with mask and threshold")
+
+Alternatively, the API also provides variant of tensorflow layers with these
+auxiliary variables built-in (see
+[layers](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers))
+. Layers currently supported:
+
+* [layers.masked_conv2d](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/layers.py?l=83)
+
+* [layers.masked_fully_connected](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/layers.py?l=241)
+
+* [rnn_cells.MaskedLSTMCell](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py?l=154)
+
+### Adding pruning ops to the training graph
+
+The pruning library allows for specification of the following hyper parameters:
+
+| Hyperparameter | Type | Default | Description |
+| ---------------------------- | ------- | ------------- | -------------- |
+| name | string | model_pruning | Name of the |
+: : : : pruning :
+: : : : specification. :
+: : : : Used for :
+: : : : adding :
+: : : : summaries and :
+: : : : ops under a :
+: : : : common :
+: : : : tensorflow :
+: : : : name_scope :
+| begin_pruning_step | integer | 0 | The global |
+: : : : step at which :
+: : : : to begin :
+: : : : pruning :
+| end_pruning_step | integer | -1 | The global |
+: : : : step at which :
+: : : : to terminate :
+: : : : pruning. :
+: : : : Defaults to -1 :
+: : : : implying that :
+: : : : pruning :
+: : : : continues till :
+: : : : the training :
+: : : : stops :
+| do_not_prune | list of | [""] | list of layers |
+: : strings : : that are not :
+: : : : pruned :
+| threshold_decay | float | 0.9 | The decay |
+: : : : factor to use :
+: : : : for :
+: : : : exponential :
+: : : : decay of the :
+: : : : thresholds :
+| pruning_frequency | integer | 10 | How often |
+: : : : should the :
+: : : : masks be :
+: : : : updated? (in # :
+: : : : of :
+: : : : global_steps). :
+| nbins | integer | 255 | Number of bins |
+: : : : to use for :
+: : : : histogram :
+: : : : computation :
+| initial_sparsity | float | 0.0 | Initial |
+: : : : sparsity value :
+| target_sparsity | float | 0.5 | Target |
+: : : : sparsity value :
+| sparsity_function_begin_step | integer | 0 | The global |
+: : : : step at this :
+: : : : which the :
+: : : : gradual :
+: : : : sparsity :
+: : : : function :
+: : : : begins to take :
+: : : : effect :
+| sparsity_function_end_step | integer | 100 | The global |
+: : : : step used as :
+: : : : the end point :
+: : : : for the :
+: : : : gradual :
+: : : : sparsity :
+: : : : function :
+| sparsity_function_exponent | float | 3.0 | exponent = 1 |
+: : : : is linearly :
+: : : : varying :
+: : : : sparsity :
+: : : : between :
+: : : : initial and :
+: : : : final. :
+: : : : exponent > 1 :
+: : : : varies more :
+: : : : slowly towards :
+: : : : the end than :
+: : : : the beginning :
+
+The sparsity $$s_t$$ at global step $$t$$ is given by:
+
+$$ s_{t}=s_{f}+\left(s_{i}-s_{f}\right)\left(1-\frac{t-t_{0}}{n\Delta t}\right)^{3} $$
+
+The interval between sparsity_function_begin_step and sparsity_function_end_step
+is divided into $$n$$ intervals of size equal to the pruning_frequency ($$\Delta
+t$$). $$s_f$$ is the target_sparsity, $$s_i$$ is the initial_sparsity, $$t_0$$
+is the sparsity_function_begin_step. In this equation, the
+sparsity_function_exponent is set to 3.
+### Adding pruning ops to the training graph
+
+The final step involves adding ops to the training graph that monitors the
+distribution of the layer's weight magnitudes and determines the layer threshold
+such masking all the weights below this threshold achieves the sparsity level
+desired for the current training step. This can be achieved as follows:
+
+```python
+tf.app.flags.DEFINE_string(
+ 'pruning_hparams', '',
+ """Comma separated list of pruning-related hyperparameters""")
+
+with tf.graph.as_default():
+
+ # Create global step variable
+ global_step = tf.train.get_global_step()
+
+ # Parse pruning hyperparameters
+ pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
+
+ # Create a pruning object using the pruning specification
+ p = pruning.Pruning(pruning_hparams, global_step=global_step)
+
+ # Add conditional mask update op. Executing this op will update all
+ # the masks in the graph if the current global step is in the range
+ # [begin_pruning_step, end_pruning_step] as specified by the pruning spec
+ mask_update_op = p.conditional_mask_update_op()
+
+ # Add summaries to keep track of the sparsity in different layers during training
+ p.add_pruning_summaries()
+
+ with tf.train.MonitoredTrainingSession(...) as mon_sess:
+ # Run the usual training op in the tf session
+ mon_sess.run(train_op)
+
+ # Update the masks by running the mask_update_op
+ mon_sess.run(mask_update_op)
+
+```
+
+## Example: Pruning and training deep CNNs on the cifar10 dataset
+
+Please see https://www.tensorflow.org/tutorials/deep_cnn for details on neural
+network architecture, setting up inputs etc. The additional changes needed to
+incorporate pruning are captured in the following:
+
+* [cifar10_pruning.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py)
+ creates a deep CNN with the same architecture, but adds mask and threshold
+ variables for each of the weight tensors in the convolutional and
+ locally-connected layers.
+
+* [cifar10_train.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py)
+ add pruning ops to the training graph as described above.
+
+To train the pruned version of cifar10:
+
+```bash
+$ examples_dir=contrib/model_pruning/examples
+$ bazel build -c opt $examples_dir/cifar10:cifar10_{train,eval}
+$ bazel-bin/$examples_dir/cifar10/cifar10_train --pruning_hparams=name=cifar10_pruning,begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000
+```
+
+Eval:
+
+```shell
+$ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once
+```
+
+TODO(suyoggupta): Add figures showing the sparsity function, sparsity for
+different layers etc.
diff --git a/tensorflow/contrib/model_pruning/__init__.py b/tensorflow/contrib/model_pruning/__init__.py
new file mode 100644
index 0000000000..aaeb2238a4
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/__init__.py
@@ -0,0 +1,46 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model pruning implementation in tensorflow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import
+from tensorflow.contrib.model_pruning.python.layers.layers import masked_conv2d
+from tensorflow.contrib.model_pruning.python.layers.layers import masked_convolution
+from tensorflow.contrib.model_pruning.python.layers.layers import masked_fully_connected
+from tensorflow.contrib.model_pruning.python.layers.rnn_cells import MaskedBasicLSTMCell
+from tensorflow.contrib.model_pruning.python.layers.rnn_cells import MaskedLSTMCell
+from tensorflow.contrib.model_pruning.python.learning import train
+from tensorflow.contrib.model_pruning.python.pruning import apply_mask
+from tensorflow.contrib.model_pruning.python.pruning import get_masked_weights
+from tensorflow.contrib.model_pruning.python.pruning import get_masks
+from tensorflow.contrib.model_pruning.python.pruning import get_thresholds
+from tensorflow.contrib.model_pruning.python.pruning import get_weight_sparsity
+from tensorflow.contrib.model_pruning.python.pruning import get_weights
+from tensorflow.contrib.model_pruning.python.pruning import Pruning
+# pylint: enable=unused-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ 'masked_convolution', 'masked_conv2d', 'masked_fully_connected',
+ 'MaskedBasicLSTMCell', 'MaskedLSTMCell', 'train', 'apply_mask',
+ 'get_masked_weights', 'get_masks', 'get_thresholds', 'get_weights',
+ 'get_weight_sparsity', 'Pruning'
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/BUILD b/tensorflow/contrib/model_pruning/examples/cifar10/BUILD
new file mode 100644
index 0000000000..299278ae75
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/examples/cifar10/BUILD
@@ -0,0 +1,77 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Description:
+# Example TensorFlow models for CIFAR-10
+
+package(
+ default_visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+py_library(
+ name = "cifar10_input",
+ srcs = ["cifar10_input.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "cifar10_pruning",
+ srcs = ["cifar10_pruning.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":cifar10_input",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
+ name = "cifar10_eval",
+ srcs = [
+ "cifar10_eval.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":cifar10_pruning",
+ ],
+)
+
+py_binary(
+ name = "cifar10_train",
+ srcs = [
+ "cifar10_train.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":cifar10_pruning",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval.py
new file mode 100644
index 0000000000..d72b2a1dca
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval.py
@@ -0,0 +1,178 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Evaluation for CIFAR-10.
+
+Accuracy:
+cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs
+of data) as judged by cifar10_eval.py.
+
+Speed:
+On a single Tesla K40, cifar10_train.py processes a single batch of 128 images
+in 0.25-0.35 sec (i.e. 350 - 600 images /sec). The model reaches ~86%
+accuracy after 100K steps in 8 hours of training time.
+
+Usage:
+Please see the tutorial and website for how to download the CIFAR-10
+data set, compile the program and train the model.
+
+http://tensorflow.org/tutorials/deep_cnn/
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import datetime
+import math
+import sys
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_pruning as cifar10
+
+FLAGS = None
+
+
+def eval_once(saver, summary_writer, top_k_op, summary_op):
+ """Run Eval once.
+
+ Args:
+ saver: Saver.
+ summary_writer: Summary writer.
+ top_k_op: Top K op.
+ summary_op: Summary op.
+ """
+ with tf.Session() as sess:
+ ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
+ if ckpt and ckpt.model_checkpoint_path:
+ # Restores from checkpoint
+ saver.restore(sess, ckpt.model_checkpoint_path)
+ # Assuming model_checkpoint_path looks something like:
+ # /my-favorite-path/cifar10_train/model.ckpt-0,
+ # extract global_step from it.
+ global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
+ else:
+ print('No checkpoint file found')
+ return
+
+ # Start the queue runners.
+ coord = tf.train.Coordinator()
+ try:
+ threads = []
+ for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
+ threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
+ start=True))
+
+ num_iter = int(math.ceil(FLAGS.num_examples / 128))
+ true_count = 0 # Counts the number of correct predictions.
+ total_sample_count = num_iter * 128
+ step = 0
+ while step < num_iter and not coord.should_stop():
+ predictions = sess.run([top_k_op])
+ true_count += np.sum(predictions)
+ step += 1
+
+ # Compute precision @ 1.
+ precision = true_count / total_sample_count
+ print('%s: precision @ 1 = %.3f' % (datetime.datetime.now(), precision))
+
+ summary = tf.Summary()
+ summary.ParseFromString(sess.run(summary_op))
+ summary.value.add(tag='Precision @ 1', simple_value=precision)
+ summary_writer.add_summary(summary, global_step)
+ except Exception as e: # pylint: disable=broad-except
+ coord.request_stop(e)
+
+ coord.request_stop()
+ coord.join(threads, stop_grace_period_secs=10)
+
+
+def evaluate():
+ """Eval CIFAR-10 for a number of steps."""
+ with tf.Graph().as_default() as g:
+ # Get images and labels for CIFAR-10.
+ eval_data = FLAGS.eval_data == 'test'
+ images, labels = cifar10.inputs(eval_data=eval_data)
+
+ # Build a Graph that computes the logits predictions from the
+ # inference model.
+ logits = cifar10.inference(images)
+
+ # Calculate predictions.
+ top_k_op = tf.nn.in_top_k(logits, labels, 1)
+
+ # Restore the moving average version of the learned variables for eval.
+ variable_averages = tf.train.ExponentialMovingAverage(
+ cifar10.MOVING_AVERAGE_DECAY)
+ variables_to_restore = variable_averages.variables_to_restore()
+ saver = tf.train.Saver(variables_to_restore)
+
+ # Build the summary operation based on the TF collection of Summaries.
+ summary_op = tf.summary.merge_all()
+
+ summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
+
+ while True:
+ eval_once(saver, summary_writer, top_k_op, summary_op)
+ if FLAGS.run_once:
+ break
+ time.sleep(FLAGS.eval_interval_secs)
+
+
+def main(argv=None): # pylint: disable=unused-argument
+ cifar10.maybe_download_and_extract()
+ if tf.gfile.Exists(FLAGS.eval_dir):
+ tf.gfile.DeleteRecursively(FLAGS.eval_dir)
+ tf.gfile.MakeDirs(FLAGS.eval_dir)
+ evaluate()
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--eval_dir',
+ type=str,
+ default='/tmp/cifar10_eval',
+ help='Directory where to write event logs.')
+ parser.add_argument(
+ '--eval_data',
+ type=str,
+ default='test',
+ help="""Either 'test' or 'train_eval'.""")
+ parser.add_argument(
+ '--checkpoint_dir',
+ type=str,
+ default='/tmp/cifar10_train',
+ help="""Directory where to read model checkpoints.""")
+ parser.add_argument(
+ '--eval_interval_secs',
+ type=int,
+ default=60 * 5,
+ help='How often to run the eval.')
+ parser.add_argument(
+ '--num_examples',
+ type=int,
+ default=10000,
+ help='Number of examples to run.')
+ parser.add_argument(
+ '--run_once',
+ type=bool,
+ default=False,
+ help='Whether to run eval only once.')
+
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py
new file mode 100644
index 0000000000..d07fece4bc
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py
@@ -0,0 +1,256 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Routine for decoding the CIFAR-10 binary file format."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+# Process images of this size. Note that this differs from the original CIFAR
+# image size of 32 x 32. If one alters this number, then the entire model
+# architecture will change and any model would need to be retrained.
+IMAGE_SIZE = 24
+
+# Global constants describing the CIFAR-10 data set.
+NUM_CLASSES = 10
+NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
+NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
+
+
+def read_cifar10(filename_queue):
+ """Reads and parses examples from CIFAR10 data files.
+
+ Recommendation: if you want N-way read parallelism, call this function
+ N times. This will give you N independent Readers reading different
+ files & positions within those files, which will give better mixing of
+ examples.
+
+ Args:
+ filename_queue: A queue of strings with the filenames to read from.
+
+ Returns:
+ An object representing a single example, with the following fields:
+ height: number of rows in the result (32)
+ width: number of columns in the result (32)
+ depth: number of color channels in the result (3)
+ key: a scalar string Tensor describing the filename & record number
+ for this example.
+ label: an int32 Tensor with the label in the range 0..9.
+ uint8image: a [height, width, depth] uint8 Tensor with the image data
+ """
+
+ class CIFAR10Record(object):
+ pass
+ result = CIFAR10Record()
+
+ # Dimensions of the images in the CIFAR-10 dataset.
+ # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
+ # input format.
+ label_bytes = 1 # 2 for CIFAR-100
+ result.height = 32
+ result.width = 32
+ result.depth = 3
+ image_bytes = result.height * result.width * result.depth
+ # Every record consists of a label followed by the image, with a
+ # fixed number of bytes for each.
+ record_bytes = label_bytes + image_bytes
+
+ # Read a record, getting filenames from the filename_queue. No
+ # header or footer in the CIFAR-10 format, so we leave header_bytes
+ # and footer_bytes at their default of 0.
+ reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
+ result.key, value = reader.read(filename_queue)
+
+ # Convert from a string to a vector of uint8 that is record_bytes long.
+ record_bytes = tf.decode_raw(value, tf.uint8)
+
+ # The first bytes represent the label, which we convert from uint8->int32.
+ result.label = tf.cast(
+ tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
+
+ # The remaining bytes after the label represent the image, which we reshape
+ # from [depth * height * width] to [depth, height, width].
+ depth_major = tf.reshape(
+ tf.strided_slice(record_bytes, [label_bytes],
+ [label_bytes + image_bytes]),
+ [result.depth, result.height, result.width])
+ # Convert from [depth, height, width] to [height, width, depth].
+ result.uint8image = tf.transpose(depth_major, [1, 2, 0])
+
+ return result
+
+
+def _generate_image_and_label_batch(image, label, min_queue_examples,
+ batch_size, shuffle):
+ """Construct a queued batch of images and labels.
+
+ Args:
+ image: 3-D Tensor of [height, width, 3] of type.float32.
+ label: 1-D Tensor of type.int32
+ min_queue_examples: int32, minimum number of samples to retain
+ in the queue that provides of batches of examples.
+ batch_size: Number of images per batch.
+ shuffle: boolean indicating whether to use a shuffling queue.
+
+ Returns:
+ images: Images. 4D tensor of [batch_size, height, width, 3] size.
+ labels: Labels. 1D tensor of [batch_size] size.
+ """
+ # Create a queue that shuffles the examples, and then
+ # read 'batch_size' images + labels from the example queue.
+ num_preprocess_threads = 16
+ if shuffle:
+ images, label_batch = tf.train.shuffle_batch(
+ [image, label],
+ batch_size=batch_size,
+ num_threads=num_preprocess_threads,
+ capacity=min_queue_examples + 3 * batch_size,
+ min_after_dequeue=min_queue_examples)
+ else:
+ images, label_batch = tf.train.batch(
+ [image, label],
+ batch_size=batch_size,
+ num_threads=num_preprocess_threads,
+ capacity=min_queue_examples + 3 * batch_size)
+
+ # Display the training images in the visualizer.
+ tf.summary.image('images', images)
+
+ return images, tf.reshape(label_batch, [batch_size])
+
+
+def distorted_inputs(data_dir, batch_size):
+ """Construct distorted input for CIFAR training using the Reader ops.
+
+ Args:
+ data_dir: Path to the CIFAR-10 data directory.
+ batch_size: Number of images per batch.
+
+ Returns:
+ images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
+ labels: Labels. 1D tensor of [batch_size] size.
+ """
+ filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
+ for i in xrange(1, 6)]
+ for f in filenames:
+ if not tf.gfile.Exists(f):
+ raise ValueError('Failed to find file: ' + f)
+
+ # Create a queue that produces the filenames to read.
+ filename_queue = tf.train.string_input_producer(filenames)
+
+ # Read examples from files in the filename queue.
+ read_input = read_cifar10(filename_queue)
+ reshaped_image = tf.cast(read_input.uint8image, tf.float32)
+
+ height = IMAGE_SIZE
+ width = IMAGE_SIZE
+
+ # Image processing for training the network. Note the many random
+ # distortions applied to the image.
+
+ # Randomly crop a [height, width] section of the image.
+ distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
+
+ # Randomly flip the image horizontally.
+ distorted_image = tf.image.random_flip_left_right(distorted_image)
+
+ # Because these operations are not commutative, consider randomizing
+ # the order their operation.
+ distorted_image = tf.image.random_brightness(distorted_image,
+ max_delta=63)
+ distorted_image = tf.image.random_contrast(distorted_image,
+ lower=0.2, upper=1.8)
+
+ # Subtract off the mean and divide by the variance of the pixels.
+ float_image = tf.image.per_image_standardization(distorted_image)
+
+ # Set the shapes of tensors.
+ float_image.set_shape([height, width, 3])
+ read_input.label.set_shape([1])
+
+ # Ensure that the random shuffling has good mixing properties.
+ min_fraction_of_examples_in_queue = 0.4
+ min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
+ min_fraction_of_examples_in_queue)
+ print ('Filling queue with %d CIFAR images before starting to train. '
+ 'This will take a few minutes.' % min_queue_examples)
+
+ # Generate a batch of images and labels by building up a queue of examples.
+ return _generate_image_and_label_batch(float_image, read_input.label,
+ min_queue_examples, batch_size,
+ shuffle=True)
+
+
+def inputs(eval_data, data_dir, batch_size):
+ """Construct input for CIFAR evaluation using the Reader ops.
+
+ Args:
+ eval_data: bool, indicating if one should use the train or eval data set.
+ data_dir: Path to the CIFAR-10 data directory.
+ batch_size: Number of images per batch.
+
+ Returns:
+ images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
+ labels: Labels. 1D tensor of [batch_size] size.
+ """
+ if not eval_data:
+ filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
+ for i in xrange(1, 6)]
+ num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
+ else:
+ filenames = [os.path.join(data_dir, 'test_batch.bin')]
+ num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
+
+ for f in filenames:
+ if not tf.gfile.Exists(f):
+ raise ValueError('Failed to find file: ' + f)
+
+ # Create a queue that produces the filenames to read.
+ filename_queue = tf.train.string_input_producer(filenames)
+
+ # Read examples from files in the filename queue.
+ read_input = read_cifar10(filename_queue)
+ reshaped_image = tf.cast(read_input.uint8image, tf.float32)
+
+ height = IMAGE_SIZE
+ width = IMAGE_SIZE
+
+ # Image processing for evaluation.
+ # Crop the central [height, width] of the image.
+ resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
+ width, height)
+
+ # Subtract off the mean and divide by the variance of the pixels.
+ float_image = tf.image.per_image_standardization(resized_image)
+
+ # Set the shapes of tensors.
+ float_image.set_shape([height, width, 3])
+ read_input.label.set_shape([1])
+
+ # Ensure that the random shuffling has good mixing properties.
+ min_fraction_of_examples_in_queue = 0.4
+ min_queue_examples = int(num_examples_per_epoch *
+ min_fraction_of_examples_in_queue)
+
+ # Generate a batch of images and labels by building up a queue of examples.
+ return _generate_image_and_label_batch(float_image, read_input.label,
+ min_queue_examples, batch_size,
+ shuffle=False)
diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py
new file mode 100644
index 0000000000..0d1de869f6
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py
@@ -0,0 +1,395 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Builds the CIFAR-10 network with additional variables to support pruning.
+
+Summary of available functions:
+
+ # Compute input images and labels for training. If you would like to run
+ # evaluations, use inputs() instead.
+ inputs, labels = distorted_inputs()
+
+ # Compute inference on the model inputs to make a prediction.
+ predictions = inference(inputs)
+
+ # Compute the total loss of the prediction with respect to the labels.
+ loss = loss(predictions, labels)
+
+ # Create a graph to run one step of training with respect to the loss.
+ train_op = train(loss, global_step)
+"""
+# pylint: disable=missing-docstring
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import re
+import sys
+import tarfile
+
+from six.moves import urllib
+import tensorflow as tf
+
+from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_input
+from tensorflow.contrib.model_pruning.python import pruning
+
+# Global constants describing the CIFAR-10 data set.
+IMAGE_SIZE = cifar10_input.IMAGE_SIZE
+NUM_CLASSES = cifar10_input.NUM_CLASSES
+NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
+NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
+BATCH_SIZE = 128
+DATA_DIR = '/tmp/cifar10_data'
+
+# Constants describing the training process.
+MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average.
+NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays.
+LEARNING_RATE_DECAY_FACTOR = 0.1 # Learning rate decay factor.
+INITIAL_LEARNING_RATE = 0.1 # Initial learning rate.
+
+# If a model is trained with multiple GPUs, prefix all Op names with tower_name
+# to differentiate the operations. Note that this prefix is removed from the
+# names of the summaries when visualizing a model.
+TOWER_NAME = 'tower'
+
+DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
+
+
+def _activation_summary(x):
+ """Helper to create summaries for activations.
+
+ Creates a summary that provides a histogram of activations.
+ Creates a summary that measures the sparsity of activations.
+
+ Args:
+ x: Tensor
+ Returns:
+ nothing
+ """
+ # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
+ # session. This helps the clarity of presentation on tensorboard.
+ tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
+ tf.summary.histogram(tensor_name + '/activations', x)
+ tf.summary.scalar(tensor_name + '/sparsity',
+ tf.nn.zero_fraction(x))
+
+
+def _variable_on_cpu(name, shape, initializer):
+ """Helper to create a Variable stored on CPU memory.
+
+ Args:
+ name: name of the variable
+ shape: list of ints
+ initializer: initializer for Variable
+
+ Returns:
+ Variable Tensor
+ """
+ with tf.device('/cpu:0'):
+ dtype = tf.float32
+ var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
+ return var
+
+
+def _variable_with_weight_decay(name, shape, stddev, wd):
+ """Helper to create an initialized Variable with weight decay.
+
+ Note that the Variable is initialized with a truncated normal distribution.
+ A weight decay is added only if one is specified.
+
+ Args:
+ name: name of the variable
+ shape: list of ints
+ stddev: standard deviation of a truncated Gaussian
+ wd: add L2Loss weight decay multiplied by this float. If None, weight
+ decay is not added for this Variable.
+
+ Returns:
+ Variable Tensor
+ """
+ dtype = tf.float32
+ var = _variable_on_cpu(
+ name,
+ shape,
+ tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
+ if wd is not None:
+ weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
+ tf.add_to_collection('losses', weight_decay)
+ return var
+
+
+def distorted_inputs():
+ """Construct distorted input for CIFAR training using the Reader ops.
+
+ Returns:
+ images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
+ labels: Labels. 1D tensor of [batch_size] size.
+
+ Raises:
+ ValueError: If no data_dir
+ """
+ if not DATA_DIR:
+ raise ValueError('Please supply a data_dir')
+ data_dir = os.path.join(DATA_DIR, 'cifar-10-batches-bin')
+ images, labels = cifar10_input.distorted_inputs(
+ data_dir=data_dir, batch_size=BATCH_SIZE)
+ return images, labels
+
+
+def inputs(eval_data):
+ """Construct input for CIFAR evaluation using the Reader ops.
+
+ Args:
+ eval_data: bool, indicating if one should use the train or eval data set.
+
+ Returns:
+ images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
+ labels: Labels. 1D tensor of [batch_size] size.
+
+ Raises:
+ ValueError: If no data_dir
+ """
+ if not DATA_DIR:
+ raise ValueError('Please supply a data_dir')
+ data_dir = os.path.join(DATA_DIR, 'cifar-10-batches-bin')
+ images, labels = cifar10_input.inputs(
+ eval_data=eval_data, data_dir=data_dir, batch_size=BATCH_SIZE)
+ return images, labels
+
+
+def inference(images):
+ """Build the CIFAR-10 model.
+
+ Args:
+ images: Images returned from distorted_inputs() or inputs().
+
+ Returns:
+ Logits.
+ """
+ # We instantiate all variables using tf.get_variable() instead of
+ # tf.Variable() in order to share variables across multiple GPU training runs.
+ # If we only ran this model on a single GPU, we could simplify this function
+ # by replacing all instances of tf.get_variable() with tf.Variable().
+ #
+ # While instantiating conv and local layers, we add mask and threshold
+ # variables to the layer by calling the pruning.apply_mask() function.
+ # Note that the masks are applied only to the weight tensors
+ # conv1
+ with tf.variable_scope('conv1') as scope:
+ kernel = _variable_with_weight_decay('weights',
+ shape=[5, 5, 3, 64],
+ stddev=5e-2,
+ wd=0.0)
+
+ conv = tf.nn.conv2d(
+ images, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')
+ biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
+ pre_activation = tf.nn.bias_add(conv, biases)
+ conv1 = tf.nn.relu(pre_activation, name=scope.name)
+ _activation_summary(conv1)
+
+ # pool1
+ pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
+ padding='SAME', name='pool1')
+ # norm1
+ norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
+ name='norm1')
+
+ # conv2
+ with tf.variable_scope('conv2') as scope:
+ kernel = _variable_with_weight_decay('weights',
+ shape=[5, 5, 64, 64],
+ stddev=5e-2,
+ wd=0.0)
+ conv = tf.nn.conv2d(
+ norm1, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')
+ biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
+ pre_activation = tf.nn.bias_add(conv, biases)
+ conv2 = tf.nn.relu(pre_activation, name=scope.name)
+ _activation_summary(conv2)
+
+ # norm2
+ norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
+ name='norm2')
+ # pool2
+ pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],
+ strides=[1, 2, 2, 1], padding='SAME', name='pool2')
+
+ # local3
+ with tf.variable_scope('local3') as scope:
+ # Move everything into depth so we can perform a single matrix multiply.
+ reshape = tf.reshape(pool2, [BATCH_SIZE, -1])
+ dim = reshape.get_shape()[1].value
+ weights = _variable_with_weight_decay('weights', shape=[dim, 384],
+ stddev=0.04, wd=0.004)
+ biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
+ local3 = tf.nn.relu(
+ tf.matmul(reshape, pruning.apply_mask(weights, scope)) + biases,
+ name=scope.name)
+ _activation_summary(local3)
+
+ # local4
+ with tf.variable_scope('local4') as scope:
+ weights = _variable_with_weight_decay('weights', shape=[384, 192],
+ stddev=0.04, wd=0.004)
+ biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
+ local4 = tf.nn.relu(
+ tf.matmul(local3, pruning.apply_mask(weights, scope)) + biases,
+ name=scope.name)
+ _activation_summary(local4)
+
+ # linear layer(WX + b),
+ # We don't apply softmax here because
+ # tf.nn.sparse_softmax_cross_entropy_with_logits accepts the unscaled logits
+ # and performs the softmax internally for efficiency.
+ with tf.variable_scope('softmax_linear') as scope:
+ weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES],
+ stddev=1/192.0, wd=0.0)
+ biases = _variable_on_cpu('biases', [NUM_CLASSES],
+ tf.constant_initializer(0.0))
+ softmax_linear = tf.add(
+ tf.matmul(local4, pruning.apply_mask(weights, scope)),
+ biases,
+ name=scope.name)
+ _activation_summary(softmax_linear)
+
+ return softmax_linear
+
+
+def loss(logits, labels):
+ """Add L2Loss to all the trainable variables.
+
+ Add summary for "Loss" and "Loss/avg".
+ Args:
+ logits: Logits from inference().
+ labels: Labels from distorted_inputs or inputs(). 1-D tensor
+ of shape [batch_size]
+
+ Returns:
+ Loss tensor of type float.
+ """
+ # Calculate the average cross entropy loss across the batch.
+ labels = tf.cast(labels, tf.int64)
+ cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits, name='cross_entropy_per_example')
+ cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
+ tf.add_to_collection('losses', cross_entropy_mean)
+
+ # The total loss is defined as the cross entropy loss plus all of the weight
+ # decay terms (L2 loss).
+ return tf.add_n(tf.get_collection('losses'), name='total_loss')
+
+
+def _add_loss_summaries(total_loss):
+ """Add summaries for losses in CIFAR-10 model.
+
+ Generates moving average for all losses and associated summaries for
+ visualizing the performance of the network.
+
+ Args:
+ total_loss: Total loss from loss().
+ Returns:
+ loss_averages_op: op for generating moving averages of losses.
+ """
+ # Compute the moving average of all individual losses and the total loss.
+ loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
+ losses = tf.get_collection('losses')
+ loss_averages_op = loss_averages.apply(losses + [total_loss])
+
+ # Attach a scalar summary to all individual losses and the total loss; do the
+ # same for the averaged version of the losses.
+ for l in losses + [total_loss]:
+ # Name each loss as '(raw)' and name the moving average version of the loss
+ # as the original loss name.
+ tf.summary.scalar(l.op.name + ' (raw)', l)
+ tf.summary.scalar(l.op.name, loss_averages.average(l))
+
+ return loss_averages_op
+
+
+def train(total_loss, global_step):
+ """Train CIFAR-10 model.
+
+ Create an optimizer and apply to all trainable variables. Add moving
+ average for all trainable variables.
+
+ Args:
+ total_loss: Total loss from loss().
+ global_step: Integer Variable counting the number of training steps
+ processed.
+ Returns:
+ train_op: op for training.
+ """
+ # Variables that affect learning rate.
+ num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / BATCH_SIZE
+ decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
+
+ # Decay the learning rate exponentially based on the number of steps.
+ lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
+ global_step,
+ decay_steps,
+ LEARNING_RATE_DECAY_FACTOR,
+ staircase=True)
+ tf.summary.scalar('learning_rate', lr)
+
+ # Generate moving averages of all losses and associated summaries.
+ loss_averages_op = _add_loss_summaries(total_loss)
+
+ # Compute gradients.
+ with tf.control_dependencies([loss_averages_op]):
+ opt = tf.train.GradientDescentOptimizer(lr)
+ grads = opt.compute_gradients(total_loss)
+
+ # Apply gradients.
+ apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
+
+ # Add histograms for trainable variables.
+ for var in tf.trainable_variables():
+ tf.summary.histogram(var.op.name, var)
+
+ # Add histograms for gradients.
+ for grad, var in grads:
+ if grad is not None:
+ tf.summary.histogram(var.op.name + '/gradients', grad)
+
+ # Track the moving averages of all trainable variables.
+ variable_averages = tf.train.ExponentialMovingAverage(
+ MOVING_AVERAGE_DECAY, global_step)
+ variables_averages_op = variable_averages.apply(tf.trainable_variables())
+
+ with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
+ train_op = tf.no_op(name='train')
+
+ return train_op
+
+
+def maybe_download_and_extract():
+ """Download and extract the tarball from Alex's website."""
+ dest_directory = DATA_DIR
+ if not os.path.exists(dest_directory):
+ os.makedirs(dest_directory)
+ filename = DATA_URL.split('/')[-1]
+ filepath = os.path.join(dest_directory, filename)
+ if not os.path.exists(filepath):
+ def _progress(count, block_size, total_size):
+ sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
+ float(count * block_size) / float(total_size) * 100.0))
+ sys.stdout.flush()
+ filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
+ print()
+ statinfo = os.stat(filepath)
+ print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
+
+ tarfile.open(filepath, 'r:gz').extractall(dest_directory)
diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py
new file mode 100644
index 0000000000..a1064a3b6a
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py
@@ -0,0 +1,159 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A binary to train pruned CIFAR-10 using a single GPU.
+
+Accuracy:
+cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
+data) as judged by cifar10_eval.py when target sparsity in
+cifar10_pruning_spec.pbtxt is set to zero
+
+Results:
+Sparsity | Accuracy after 150K steps
+-------- | -------------------------
+0% | 86%
+50% | 86%
+75% | TODO(suyoggupta)
+90% | TODO(suyoggupta)
+95% | 77%
+
+Usage:
+Please see the tutorial and website for how to download the CIFAR-10
+data set, compile the program and train the model.
+
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import datetime
+import sys
+import time
+
+
+import tensorflow as tf
+
+from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_pruning as cifar10
+from tensorflow.contrib.model_pruning.python import pruning
+
+FLAGS = None
+
+
+def train():
+ """Train CIFAR-10 for a number of steps."""
+ with tf.Graph().as_default():
+ global_step = tf.contrib.framework.get_or_create_global_step()
+
+ # Get images and labels for CIFAR-10.
+ images, labels = cifar10.distorted_inputs()
+
+ # Build a Graph that computes the logits predictions from the
+ # inference model.
+ logits = cifar10.inference(images)
+
+ # Calculate loss.
+ loss = cifar10.loss(logits, labels)
+
+ # Build a Graph that trains the model with one batch of examples and
+ # updates the model parameters.
+ train_op = cifar10.train(loss, global_step)
+
+ # Parse pruning hyperparameters
+ pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
+
+ # Create a pruning object using the pruning hyperparameters
+ pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
+
+ # Use the pruning_obj to add ops to the training graph to update the masks
+ # The conditional_mask_update_op will update the masks only when the
+ # training step is in [begin_pruning_step, end_pruning_step] specified in
+ # the pruning spec proto
+ mask_update_op = pruning_obj.conditional_mask_update_op()
+
+ # Use the pruning_obj to add summaries to the graph to track the sparsity
+ # of each of the layers
+ pruning_obj.add_pruning_summaries()
+
+ class _LoggerHook(tf.train.SessionRunHook):
+ """Logs loss and runtime."""
+
+ def begin(self):
+ self._step = -1
+
+ def before_run(self, run_context):
+ self._step += 1
+ self._start_time = time.time()
+ return tf.train.SessionRunArgs(loss) # Asks for loss value.
+
+ def after_run(self, run_context, run_values):
+ duration = time.time() - self._start_time
+ loss_value = run_values.results
+ if self._step % 10 == 0:
+ num_examples_per_step = 128
+ examples_per_sec = num_examples_per_step / duration
+ sec_per_batch = float(duration)
+
+ format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
+ 'sec/batch)')
+ print(format_str % (datetime.datetime.now(), self._step, loss_value,
+ examples_per_sec, sec_per_batch))
+
+ with tf.train.MonitoredTrainingSession(
+ checkpoint_dir=FLAGS.train_dir,
+ hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
+ tf.train.NanTensorHook(loss),
+ _LoggerHook()],
+ config=tf.ConfigProto(
+ log_device_placement=FLAGS.log_device_placement)) as mon_sess:
+ while not mon_sess.should_stop():
+ mon_sess.run(train_op)
+ # Update the masks
+ mon_sess.run(mask_update_op)
+
+
+def main(argv=None): # pylint: disable=unused-argument
+ cifar10.maybe_download_and_extract()
+ if tf.gfile.Exists(FLAGS.train_dir):
+ tf.gfile.DeleteRecursively(FLAGS.train_dir)
+ tf.gfile.MakeDirs(FLAGS.train_dir)
+ train()
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--train_dir',
+ type=str,
+ default='/tmp/cifar10_train',
+ help='Directory where to write event logs and checkpoint.')
+ parser.add_argument(
+ '--pruning_hparams',
+ type=str,
+ default='',
+ help="""Comma separated list of pruning-related hyperparameters""")
+ parser.add_argument(
+ '--max_steps',
+ type=int,
+ default=1000000,
+ help='Number of batches to run.')
+ parser.add_argument(
+ '--log_device_placement',
+ type=bool,
+ default=False,
+ help='Whether to log device placement.')
+
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/model_pruning/python/layers/core_layers.py b/tensorflow/contrib/model_pruning/python/layers/core_layers.py
new file mode 100644
index 0000000000..ae60d8b1e1
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/python/layers/core_layers.py
@@ -0,0 +1,477 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains the core layer classes for model pruning and its functional aliases.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.layers import base
+from tensorflow.python.layers import utils
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import standard_ops
+
+MASK_COLLECTION = 'masks'
+THRESHOLD_COLLECTION = 'thresholds'
+MASKED_WEIGHT_COLLECTION = 'masked_weights'
+WEIGHT_COLLECTION = 'kernel'
+# The 'weights' part of the name is needed for the quantization library
+# to recognize that the kernel should be quantized.
+MASKED_WEIGHT_NAME = 'weights/masked_weight'
+
+
+class _MaskedConv(base.Layer):
+ """Abstract nD convolution layer (private, used as implementation base).
+
+ This layer creates a convolution kernel that is convolved
+ (actually cross-correlated) with the layer input to produce a tensor of
+ outputs. The weight tensor of this layer is masked.
+ If `use_bias` is True (and a `bias_initializer` is provided),
+ a bias vector is created and added to the outputs. Finally, if
+ `activation` is not `None`, it is applied to the outputs as well.
+
+ Arguments:
+ rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of n integers, specifying the
+ length of the convolution window.
+ strides: An integer or tuple/list of n integers,
+ specifying the stride length of the convolution.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, ..., channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, ...)`.
+ dilation_rate: An integer or tuple/list of n integers, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any `strides` value != 1.
+ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ use_bias: Boolean, whether the layer uses a bias.
+ kernel_initializer: An initializer for the convolution kernel.
+ bias_initializer: An initializer for the bias vector. If None, no bias will
+ be applied.
+ kernel_regularizer: Optional regularizer for the convolution kernel.
+ bias_regularizer: Optional regularizer for the bias vector.
+ activity_regularizer: Regularizer function for the output.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: A string, the name of the layer.
+ """
+
+ def __init__(self,
+ rank,
+ filters,
+ kernel_size,
+ strides=1,
+ padding='valid',
+ data_format='channels_last',
+ dilation_rate=1,
+ activation=None,
+ use_bias=True,
+ kernel_initializer=None,
+ bias_initializer=init_ops.zeros_initializer(),
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ super(_MaskedConv, self).__init__(
+ trainable=trainable,
+ name=name,
+ activity_regularizer=activity_regularizer,
+ **kwargs)
+ self.rank = rank
+ self.filters = filters
+ self.kernel_size = utils.normalize_tuple(kernel_size, rank, 'kernel_size')
+ self.strides = utils.normalize_tuple(strides, rank, 'strides')
+ self.padding = utils.normalize_padding(padding)
+ self.data_format = utils.normalize_data_format(data_format)
+ self.dilation_rate = utils.normalize_tuple(dilation_rate, rank,
+ 'dilation_rate')
+ self.activation = activation
+ self.use_bias = use_bias
+ self.kernel_initializer = kernel_initializer
+ self.bias_initializer = bias_initializer
+ self.kernel_regularizer = kernel_regularizer
+ self.bias_regularizer = bias_regularizer
+ self.input_spec = base.InputSpec(ndim=self.rank + 2)
+
+ def build(self, input_shape):
+ input_shape = tensor_shape.TensorShape(input_shape)
+ channel_axis = 1 if self.data_format == 'channels_first' else -1
+ if input_shape[channel_axis].value is None:
+ raise ValueError('The channel dimension of the inputs '
+ 'should be defined. Found `None`.')
+ input_dim = input_shape[channel_axis].value
+ kernel_shape = self.kernel_size + (input_dim, self.filters)
+ self.mask = self.add_variable(
+ name='mask',
+ shape=kernel_shape,
+ initializer=init_ops.ones_initializer(),
+ trainable=False,
+ dtype=self.dtype)
+
+ self.kernel = self.add_variable(
+ name='kernel',
+ shape=kernel_shape,
+ initializer=self.kernel_initializer,
+ regularizer=self.kernel_regularizer,
+ trainable=True,
+ dtype=self.dtype)
+
+ self.threshold = self.add_variable(
+ name='threshold',
+ shape=[],
+ initializer=init_ops.zeros_initializer(),
+ trainable=False,
+ dtype=self.dtype)
+
+ # Add masked_weights in the weights namescope so as to make it easier
+ # for the quantization library to add quant ops.
+ self.masked_kernel = math_ops.multiply(self.mask, self.kernel,
+ MASKED_WEIGHT_NAME)
+
+ ops.add_to_collection(MASK_COLLECTION, self.mask)
+ ops.add_to_collection(MASKED_WEIGHT_COLLECTION, self.masked_kernel)
+ ops.add_to_collection(THRESHOLD_COLLECTION, self.threshold)
+ ops.add_to_collection(WEIGHT_COLLECTION, self.kernel)
+
+ if self.use_bias:
+ self.bias = self.add_variable(
+ name='bias',
+ shape=(self.filters,),
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ trainable=True,
+ dtype=self.dtype)
+ else:
+ self.bias = None
+ self.input_spec = base.InputSpec(
+ ndim=self.rank + 2, axes={channel_axis: input_dim})
+ self.built = True
+
+ def call(self, inputs):
+ outputs = nn.convolution(
+ input=inputs,
+ filter=self.masked_kernel,
+ dilation_rate=self.dilation_rate,
+ strides=self.strides,
+ padding=self.padding.upper(),
+ data_format=utils.convert_data_format(self.data_format, self.rank + 2))
+
+ if self.bias is not None:
+ if self.data_format == 'channels_first':
+ if self.rank == 1:
+ # nn.bias_add does not accept a 1D input tensor.
+ bias = array_ops.reshape(self.bias, (1, self.filters, 1))
+ outputs += bias
+ if self.rank == 2:
+ outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
+ if self.rank == 3:
+ # As of Mar 2017, direct addition is significantly slower than
+ # bias_add when computing gradients. To use bias_add, we collapse Z
+ # and Y into a single dimension to obtain a 4D input tensor.
+ outputs_shape = outputs.shape.as_list()
+ outputs_4d = array_ops.reshape(outputs, [
+ outputs_shape[0], outputs_shape[1],
+ outputs_shape[2] * outputs_shape[3], outputs_shape[4]
+ ])
+ outputs_4d = nn.bias_add(outputs_4d, self.bias, data_format='NCHW')
+ outputs = array_ops.reshape(outputs_4d, outputs_shape)
+ else:
+ outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')
+
+ if self.activation is not None:
+ return self.activation(outputs)
+ return outputs
+
+ def _compute_output_shape(self, input_shape):
+ input_shape = tensor_shape.TensorShape(input_shape).as_list()
+ if self.data_format == 'channels_last':
+ space = input_shape[1:-1]
+ new_space = []
+ for i in range(len(space)):
+ new_dim = utils.conv_output_length(
+ space[i],
+ self.kernel_size[i],
+ padding=self.padding,
+ stride=self.strides[i],
+ dilation=self.dilation_rate[i])
+ new_space.append(new_dim)
+ return tensor_shape.TensorShape([input_shape[0]] + new_space +
+ [self.filters])
+ else:
+ space = input_shape[2:]
+ new_space = []
+ for i in range(len(space)):
+ new_dim = utils.conv_output_length(
+ space[i],
+ self.kernel_size[i],
+ padding=self.padding,
+ stride=self.strides[i],
+ dilation=self.dilation_rate[i])
+ new_space.append(new_dim)
+ return tensor_shape.TensorShape([input_shape[0], self.filters] +
+ new_space)
+
+
+class MaskedConv2D(_MaskedConv):
+ """2D convolution layer (e.g. spatial convolution over images).
+
+ This layer creates a convolution kernel that is convolved
+ (actually cross-correlated) with the layer input to produce a tensor of
+ outputs. If `use_bias` is True (and a `bias_initializer` is provided),
+ a bias vector is created and added to the outputs. Finally, if
+ `activation` is not `None`, it is applied to the outputs as well.
+
+ Arguments:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of 2 integers, specifying the
+ height and width of the 2D convolution window.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ strides: An integer or tuple/list of 2 integers,
+ specifying the strides of the convolution along the height and width.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, height, width, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, height, width)`.
+
+ dilation_rate: An integer or tuple/list of 2 integers, specifying
+ the dilation rate to use for dilated convolution.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any stride value != 1.
+ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ use_bias: Boolean, whether the layer uses a bias.
+ kernel_initializer: An initializer for the convolution kernel.
+ bias_initializer: An initializer for the bias vector. If None, no bias will
+ be applied.
+ kernel_regularizer: Optional regularizer for the convolution kernel.
+ bias_regularizer: Optional regularizer for the bias vector.
+ activity_regularizer: Regularizer function for the output.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: A string, the name of the layer.
+ """
+
+ def __init__(self,
+ filters,
+ kernel_size,
+ strides=(1, 1),
+ padding='valid',
+ data_format='channels_last',
+ dilation_rate=(1, 1),
+ activation=None,
+ use_bias=True,
+ kernel_initializer=None,
+ bias_initializer=init_ops.zeros_initializer(),
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ super(MaskedConv2D, self).__init__(
+ rank=2,
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ activation=activation,
+ use_bias=use_bias,
+ kernel_initializer=kernel_initializer,
+ bias_initializer=bias_initializer,
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ activity_regularizer=activity_regularizer,
+ trainable=trainable,
+ name=name,
+ **kwargs)
+
+
+class MaskedFullyConnected(base.Layer):
+ """Fully-connected layer class with masked weights.
+
+ This layer implements the operation:
+ `outputs = activation(inputs.kernel + bias)`
+ Where `activation` is the activation function passed as the `activation`
+ argument (if not `None`), `kernel` is a weights matrix created by the layer,
+ and `bias` is a bias vector created by the layer
+ (only if `use_bias` is `True`).
+
+ Note: if the input to the layer has a rank greater than 2, then it is
+ flattened prior to the initial matrix multiply by `kernel`.
+
+ Arguments:
+ units: Integer or Long, dimensionality of the output space.
+ activation: Activation function (callable). Set it to None to maintain a
+ linear activation.
+ use_bias: Boolean, whether the layer uses a bias.
+ kernel_initializer: Initializer function for the weight matrix.
+ bias_initializer: Initializer function for the bias.
+ kernel_regularizer: Regularizer function for the weight matrix.
+ bias_regularizer: Regularizer function for the bias.
+ activity_regularizer: Regularizer function for the output.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: String, the name of the layer. Layers with the same name will
+ share weights, but to avoid mistakes we require reuse=True in such cases.
+ reuse: Boolean, whether to reuse the weights of a previous layer
+ by the same name.
+
+ Properties:
+ units: Python integer, dimensionality of the output space.
+ activation: Activation function (callable).
+ use_bias: Boolean, whether the layer uses a bias.
+ kernel_initializer: Initializer instance (or name) for the weight matrix.
+ bias_initializer: Initializer instance (or name) for the bias.
+ kernel_regularizer: Regularizer instance for the weight matrix (callable)
+ bias_regularizer: Regularizer instance for the bias (callable).
+ activity_regularizer: Regularizer instance for the output (callable)
+ kernel: Weight matrix (TensorFlow variable or tensor).
+ bias: Bias vector, if applicable (TensorFlow variable or tensor).
+ """
+
+ def __init__(self,
+ units,
+ activation=None,
+ use_bias=True,
+ kernel_initializer=None,
+ bias_initializer=init_ops.zeros_initializer(),
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ super(MaskedFullyConnected, self).__init__(
+ trainable=trainable,
+ name=name,
+ activity_regularizer=activity_regularizer,
+ **kwargs)
+ self.units = units
+ self.activation = activation
+ self.use_bias = use_bias
+ self.kernel_initializer = kernel_initializer
+ self.bias_initializer = bias_initializer
+ self.kernel_regularizer = kernel_regularizer
+ self.bias_regularizer = bias_regularizer
+ self.input_spec = base.InputSpec(min_ndim=2)
+
+ def build(self, input_shape):
+ input_shape = tensor_shape.TensorShape(input_shape)
+ if input_shape[-1].value is None:
+ raise ValueError('The last dimension of the inputs to `Dense` '
+ 'should be defined. Found `None`.')
+ self.input_spec = base.InputSpec(
+ min_ndim=2, axes={-1: input_shape[-1].value})
+
+ self.kernel = self.add_variable(
+ 'kernel',
+ shape=[input_shape[-1].value, self.units],
+ initializer=self.kernel_initializer,
+ regularizer=self.kernel_regularizer,
+ dtype=self.dtype,
+ trainable=True)
+
+ self.mask = self.add_variable(
+ name='mask',
+ shape=[input_shape[-1].value, self.units],
+ initializer=init_ops.ones_initializer(),
+ trainable=False,
+ dtype=self.dtype)
+
+ self.threshold = self.add_variable(
+ name='threshold',
+ shape=[],
+ initializer=init_ops.zeros_initializer(),
+ trainable=False,
+ dtype=self.dtype)
+
+ # Add masked_weights in the weights namescope so as to make it easier
+ # for the quantization library to add quant ops.
+ self.masked_kernel = math_ops.multiply(self.mask, self.kernel,
+ MASKED_WEIGHT_NAME)
+
+ ops.add_to_collection(MASK_COLLECTION, self.mask)
+ ops.add_to_collection(MASKED_WEIGHT_COLLECTION, self.masked_kernel)
+ ops.add_to_collection(THRESHOLD_COLLECTION, self.threshold)
+ ops.add_to_collection(WEIGHT_COLLECTION, self.kernel)
+
+ if self.use_bias:
+ self.bias = self.add_variable(
+ 'bias',
+ shape=[
+ self.units,
+ ],
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ dtype=self.dtype,
+ trainable=True)
+ else:
+ self.bias = None
+ self.built = True
+
+ def call(self, inputs):
+ inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
+ shape = inputs.get_shape().as_list()
+ output_shape = shape[:-1] + [self.units]
+ if len(output_shape) > 2:
+ # Broadcasting is required for the inputs.
+ outputs = standard_ops.tensordot(inputs, self.masked_kernel,
+ [[len(shape) - 1], [0]])
+ # Reshape the output back to the original ndim of the input.
+ outputs.set_shape(output_shape)
+ else:
+ outputs = standard_ops.matmul(inputs, self.masked_kernel)
+ if self.use_bias:
+ outputs = nn.bias_add(outputs, self.bias)
+ if self.activation is not None:
+ return self.activation(outputs) # pylint: disable=not-callable
+ return outputs
+
+ def _compute_output_shape(self, input_shape):
+ input_shape = tensor_shape.TensorShape(input_shape)
+ input_shape = input_shape.with_rank_at_least(2)
+ if input_shape[-1].value is None:
+ raise ValueError(
+ 'The innermost dimension of input_shape must be defined, but saw: %s'
+ % input_shape)
+ return input_shape[:-1].concatenate(self.units)
diff --git a/tensorflow/contrib/model_pruning/python/layers/layers.py b/tensorflow/contrib/model_pruning/python/layers/layers.py
new file mode 100644
index 0000000000..dfebb9a679
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/python/layers/layers.py
@@ -0,0 +1,364 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tensorflow layers with added variables for parameter masking.
+
+Branched from tensorflow/contrib/layers/python/layers/layers.py
+"""
+# pylint: disable=missing-docstring
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import six
+
+from tensorflow.contrib.framework.python.ops import add_arg_scope
+from tensorflow.contrib.framework.python.ops import variables
+from tensorflow.contrib.layers.python.layers import initializers
+from tensorflow.contrib.layers.python.layers import utils
+from tensorflow.contrib.model_pruning.python.layers import core_layers as core
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as tf_variables
+
+
+def _model_variable_getter(getter,
+ name,
+ shape=None,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=True,
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ rename=None,
+ use_resource=None,
+ **_):
+ """Getter that uses model_variable for compatibility with core layers."""
+ short_name = name.split('/')[-1]
+ if rename and short_name in rename:
+ name_components = name.split('/')
+ name_components[-1] = rename[short_name]
+ name = '/'.join(name_components)
+ return variables.model_variable(
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ collections=collections,
+ trainable=trainable,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ custom_getter=getter,
+ use_resource=use_resource)
+
+
+def _build_variable_getter(rename=None):
+ """Build a model variable getter that respects scope getter and renames."""
+
+ # VariableScope will nest the getters
+ def layer_variable_getter(getter, *args, **kwargs):
+ kwargs['rename'] = rename
+ return _model_variable_getter(getter, *args, **kwargs)
+
+ return layer_variable_getter
+
+
+def _add_variable_to_collections(variable, collections_set, collections_name):
+ """Adds variable (or all its parts) to all collections with that name."""
+ collections = utils.get_variable_collections(collections_set,
+ collections_name) or []
+ variables_list = [variable]
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ variables_list = [v for v in variable]
+ for collection in collections:
+ for var in variables_list:
+ if var not in ops.get_collection(collection):
+ ops.add_to_collection(collection, var)
+
+
+@add_arg_scope
+def masked_convolution(inputs,
+ num_outputs,
+ kernel_size,
+ stride=1,
+ padding='SAME',
+ data_format=None,
+ rate=1,
+ activation_fn=nn.relu,
+ normalizer_fn=None,
+ normalizer_params=None,
+ weights_initializer=initializers.xavier_initializer(),
+ weights_regularizer=None,
+ biases_initializer=init_ops.zeros_initializer(),
+ biases_regularizer=None,
+ reuse=None,
+ variables_collections=None,
+ outputs_collections=None,
+ trainable=True,
+ scope=None):
+ """Adds an 2D convolution followed by an optional batch_norm layer.
+ The layer creates a mask variable on top of the weight variable. The input to
+ the convolution operation is the elementwise multiplication of the mask
+ variable and the weigh
+
+ It is required that 1 <= N <= 3.
+
+ `convolution` creates a variable called `weights`, representing the
+ convolutional kernel, that is convolved (actually cross-correlated) with the
+ `inputs` to produce a `Tensor` of activations. If a `normalizer_fn` is
+ provided (such as `batch_norm`), it is then applied. Otherwise, if
+ `normalizer_fn` is None and a `biases_initializer` is provided then a `biases`
+ variable would be created and added the activations. Finally, if
+ `activation_fn` is not `None`, it is applied to the activations as well.
+
+ Performs atrous convolution with input stride/dilation rate equal to `rate`
+ if a value > 1 for any dimension of `rate` is specified. In this case
+ `stride` values != 1 are not supported.
+
+ Args:
+ inputs: A Tensor of rank N+2 of shape
+ `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
+ not start with "NC" (default), or
+ `[batch_size, in_channels] + input_spatial_shape` if data_format starts
+ with "NC".
+ num_outputs: Integer, the number of output filters.
+ kernel_size: A sequence of N positive integers specifying the spatial
+ dimensions of of the filters. Can be a single integer to specify the same
+ value for all spatial dimensions.
+ stride: A sequence of N positive integers specifying the stride at which to
+ compute output. Can be a single integer to specify the same value for all
+ spatial dimensions. Specifying any `stride` value != 1 is incompatible
+ with specifying any `rate` value != 1.
+ padding: One of `"VALID"` or `"SAME"`.
+ data_format: A string or None. Specifies whether the channel dimension of
+ the `input` and output is the last dimension (default, or if `data_format`
+ does not start with "NC"), or the second dimension (if `data_format`
+ starts with "NC"). For N=1, the valid values are "NWC" (default) and
+ "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW".
+ For N=3, the valid values are "NDHWC" (default) and "NCDHW".
+ rate: A sequence of N positive integers specifying the dilation rate to use
+ for atrous convolution. Can be a single integer to specify the same
+ value for all spatial dimensions. Specifying any `rate` value != 1 is
+ incompatible with specifying any `stride` value != 1.
+ activation_fn: Activation function. The default value is a ReLU function.
+ Explicitly set it to None to skip it and maintain a linear activation.
+ normalizer_fn: Normalization function to use instead of `biases`. If
+ `normalizer_fn` is provided then `biases_initializer` and
+ `biases_regularizer` are ignored and `biases` are not created nor added.
+ default set to None for no normalizer function
+ normalizer_params: Normalization function parameters.
+ weights_initializer: An initializer for the weights.
+ weights_regularizer: Optional regularizer for the weights.
+ biases_initializer: An initializer for the biases. If None skip biases.
+ biases_regularizer: Optional regularizer for the biases.
+ reuse: Whether or not the layer and its variables should be reused. To be
+ able to reuse the layer scope must be given.
+ variables_collections: Optional list of collections for all the variables or
+ a dictionary containing a different list of collection per variable.
+ outputs_collections: Collection to add the outputs.
+ trainable: If `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
+ scope: Optional scope for `variable_scope`.
+
+ Returns:
+ A tensor representing the output of the operation.
+
+ Raises:
+ ValueError: If `data_format` is invalid.
+ ValueError: Both 'rate' and `stride` are not uniformly 1.
+ """
+ if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
+ raise ValueError('Invalid data_format: %r' % (data_format,))
+
+ layer_variable_getter = _build_variable_getter({
+ 'bias': 'biases',
+ 'kernel': 'weights'
+ })
+
+ with variable_scope.variable_scope(
+ scope, 'Conv', [inputs], reuse=reuse,
+ custom_getter=layer_variable_getter) as sc:
+ inputs = ops.convert_to_tensor(inputs)
+ input_rank = inputs.get_shape().ndims
+
+ if input_rank == 3:
+ raise ValueError('Sparse Convolution not supported for input with rank',
+ input_rank)
+ elif input_rank == 4:
+ layer_class = core.MaskedConv2D
+ elif input_rank == 5:
+ raise ValueError('Sparse Convolution not supported for input with rank',
+ input_rank)
+ else:
+ raise ValueError('Sparse Convolution not supported for input with rank',
+ input_rank)
+
+ if data_format is None or data_format == 'NHWC':
+ df = 'channels_last'
+ elif data_format == 'NCHW':
+ df = 'channels_first'
+ else:
+ raise ValueError('Unsupported data fromat', data_format)
+
+ layer = layer_class(
+ filters=num_outputs,
+ kernel_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ data_format=df,
+ dilation_rate=rate,
+ activation=None,
+ use_bias=not normalizer_fn and biases_initializer,
+ kernel_initializer=weights_initializer,
+ bias_initializer=biases_initializer,
+ kernel_regularizer=weights_regularizer,
+ bias_regularizer=biases_regularizer,
+ activity_regularizer=None,
+ trainable=trainable,
+ name=sc.name,
+ dtype=inputs.dtype.base_dtype,
+ _scope=sc,
+ _reuse=reuse)
+ outputs = layer.apply(inputs)
+
+ # Add variables to collections.
+ _add_variable_to_collections(layer.kernel, variables_collections, 'weights')
+ if layer.use_bias:
+ _add_variable_to_collections(layer.bias, variables_collections, 'biases')
+
+ if normalizer_fn is not None:
+ normalizer_params = normalizer_params or {}
+ outputs = normalizer_fn(outputs, **normalizer_params)
+
+ if activation_fn is not None:
+ outputs = activation_fn(outputs)
+ return utils.collect_named_outputs(outputs_collections,
+ sc.original_name_scope, outputs)
+
+
+masked_conv2d = masked_convolution
+
+
+@add_arg_scope
+def masked_fully_connected(
+ inputs,
+ num_outputs,
+ activation_fn=nn.relu,
+ normalizer_fn=None,
+ normalizer_params=None,
+ weights_initializer=initializers.xavier_initializer(),
+ weights_regularizer=None,
+ biases_initializer=init_ops.zeros_initializer(),
+ biases_regularizer=None,
+ reuse=None,
+ variables_collections=None,
+ outputs_collections=None,
+ trainable=True,
+ scope=None):
+ """Adds a sparse fully connected layer. The weight matrix is masked.
+
+ `fully_connected` creates a variable called `weights`, representing a fully
+ connected weight matrix, which is multiplied by the `inputs` to produce a
+ `Tensor` of hidden units. If a `normalizer_fn` is provided (such as
+ `batch_norm`), it is then applied. Otherwise, if `normalizer_fn` is
+ None and a `biases_initializer` is provided then a `biases` variable would be
+ created and added the hidden units. Finally, if `activation_fn` is not `None`,
+ it is applied to the hidden units as well.
+
+ Note: that if `inputs` have a rank greater than 2, then `inputs` is flattened
+ prior to the initial matrix multiply by `weights`.
+
+ Args:
+ inputs: A tensor of at least rank 2 and static value for the last dimension;
+ i.e. `[batch_size, depth]`, `[None, None, None, channels]`.
+ num_outputs: Integer or long, the number of output units in the layer.
+ activation_fn: Activation function. The default value is a ReLU function.
+ Explicitly set it to None to skip it and maintain a linear activation.
+ normalizer_fn: Normalization function to use instead of `biases`. If
+ `normalizer_fn` is provided then `biases_initializer` and
+ `biases_regularizer` are ignored and `biases` are not created nor added.
+ default set to None for no normalizer function
+ normalizer_params: Normalization function parameters.
+ weights_initializer: An initializer for the weights.
+ weights_regularizer: Optional regularizer for the weights.
+ biases_initializer: An initializer for the biases. If None skip biases.
+ biases_regularizer: Optional regularizer for the biases.
+ reuse: Whether or not the layer and its variables should be reused. To be
+ able to reuse the layer scope must be given.
+ variables_collections: Optional list of collections for all the variables or
+ a dictionary containing a different list of collections per variable.
+ outputs_collections: Collection to add the outputs.
+ trainable: If `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
+ scope: Optional scope for variable_scope.
+
+ Returns:
+ The tensor variable representing the result of the series of operations.
+
+ Raises:
+ ValueError: If x has rank less than 2 or if its last dimension is not set.
+ """
+ if not isinstance(num_outputs, six.integer_types):
+ raise ValueError('num_outputs should be int or long, got %s.' %
+ (num_outputs,))
+
+ layer_variable_getter = _build_variable_getter({
+ 'bias': 'biases',
+ 'kernel': 'weights'
+ })
+
+ with variable_scope.variable_scope(
+ scope,
+ 'fully_connected', [inputs],
+ reuse=reuse,
+ custom_getter=layer_variable_getter) as sc:
+ inputs = ops.convert_to_tensor(inputs)
+ layer = core.MaskedFullyConnected(
+ units=num_outputs,
+ activation=None,
+ use_bias=not normalizer_fn and biases_initializer,
+ kernel_initializer=weights_initializer,
+ bias_initializer=biases_initializer,
+ kernel_regularizer=weights_regularizer,
+ bias_regularizer=biases_regularizer,
+ activity_regularizer=None,
+ trainable=trainable,
+ name=sc.name,
+ dtype=inputs.dtype.base_dtype,
+ _scope=sc,
+ _reuse=reuse)
+ outputs = layer.apply(inputs)
+
+ # Add variables to collections.
+ _add_variable_to_collections(layer.kernel, variables_collections, 'weights')
+ if layer.bias is not None:
+ _add_variable_to_collections(layer.bias, variables_collections, 'biases')
+
+ # Apply normalizer function / layer.
+ if normalizer_fn is not None:
+ if not normalizer_params:
+ normalizer_params = {}
+ outputs = normalizer_fn(outputs, **normalizer_params)
+
+ if activation_fn is not None:
+ outputs = activation_fn(outputs)
+
+ return utils.collect_named_outputs(outputs_collections,
+ sc.original_name_scope, outputs)
diff --git a/tensorflow/contrib/model_pruning/python/layers/layers_test.py b/tensorflow/contrib/model_pruning/python/layers/layers_test.py
new file mode 100644
index 0000000000..97a2c97850
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/python/layers/layers_test.py
@@ -0,0 +1,139 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for imagingvision.intelligence.tensorflow.model_pruning.layers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.model_pruning.python.layers import core_layers
+from tensorflow.contrib.model_pruning.python.layers import layers
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class MaskedConvolutionLayerTest(test.TestCase):
+
+ def setUp(self):
+ super(MaskedConvolutionLayerTest, self).setUp()
+ self.height, self.width = 7, 9
+
+ def testInvalidRank3(self):
+ input_tensor = array_ops.ones((self.height, self.width, 3))
+ with self.assertRaisesRegexp(ValueError, 'rank'):
+ layers.masked_conv2d(input_tensor, 32, 3)
+
+ def testInvalidRank5(self):
+ input_tensor = array_ops.ones((8, 8, self.height, self.width, 3))
+ with self.assertRaisesRegexp(ValueError, 'rank'):
+ layers.masked_conv2d(input_tensor, 32, 3)
+
+ def testSingleConvMaskAdded(self):
+ kernel_size = 3
+ input_depth, output_depth = 8, 32
+ input_tensor = array_ops.ones((8, self.height, self.width, input_depth))
+ layers.masked_conv2d(input_tensor, output_depth, kernel_size)
+
+ masks = ops.get_collection(core_layers.MASK_COLLECTION)
+ self.assertEqual(len(masks), 1)
+ self.assertListEqual(masks[0].get_shape().as_list(),
+ [kernel_size, kernel_size, input_depth, output_depth])
+
+ masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION)
+ self.assertEqual(len(masked_weight), 1)
+ self.assertListEqual(masked_weight[0].get_shape().as_list(),
+ [kernel_size, kernel_size, input_depth, output_depth])
+
+ def testMultipleConvMaskAdded(self):
+ number_of_layers = 5
+
+ kernel_size = 3
+ base_depth = 4
+ depth_step = 7
+
+ input_tensor = array_ops.ones((8, self.height, self.width, base_depth))
+
+ top_layer = input_tensor
+
+ for ix in range(number_of_layers):
+ top_layer = layers.masked_conv2d(top_layer, base_depth +
+ (ix + 1) * depth_step, kernel_size)
+
+ masks = ops.get_collection(core_layers.MASK_COLLECTION)
+ self.assertEqual(len(masks), number_of_layers)
+ for ix in range(number_of_layers):
+ self.assertListEqual(masks[ix].get_shape().as_list(), [
+ kernel_size, kernel_size, base_depth + ix * depth_step,
+ base_depth + (ix + 1) * depth_step
+ ])
+
+ masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION)
+ self.assertEqual(len(masked_weight), number_of_layers)
+ for ix in range(number_of_layers):
+ self.assertListEqual(masked_weight[ix].get_shape().as_list(), [
+ kernel_size, kernel_size, base_depth + ix * depth_step,
+ base_depth + (ix + 1) * depth_step
+ ])
+
+
+class MaskedFullyConnectedLayerTest(test.TestCase):
+
+ def testSingleFCMaskAdded(self):
+ input_depth, output_depth = 8, 32
+ input_tensor = array_ops.ones((5, input_depth))
+ layers.masked_fully_connected(input_tensor, output_depth)
+
+ masks = ops.get_collection(core_layers.MASK_COLLECTION)
+ self.assertEqual(len(masks), 1)
+ self.assertListEqual(masks[0].get_shape().as_list(),
+ [input_depth, output_depth])
+
+ masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION)
+ self.assertEqual(len(masked_weight), 1)
+ self.assertListEqual(masked_weight[0].get_shape().as_list(),
+ [input_depth, output_depth])
+
+ def testMultipleConvMaskAdded(self):
+ number_of_layers = 5
+
+ base_depth = 4
+ depth_step = 7
+
+ input_tensor = array_ops.ones((8, base_depth))
+
+ top_layer = input_tensor
+
+ for ix in range(number_of_layers):
+ top_layer = layers.masked_fully_connected(top_layer, base_depth +
+ (ix + 1) * depth_step)
+
+ masks = ops.get_collection(core_layers.MASK_COLLECTION)
+ self.assertEqual(len(masks), number_of_layers)
+ for ix in range(number_of_layers):
+ self.assertListEqual(masks[ix].get_shape().as_list(), [
+ base_depth + ix * depth_step, base_depth + (ix + 1) * depth_step
+ ])
+
+ masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION)
+ self.assertEqual(len(masked_weight), number_of_layers)
+ for ix in range(number_of_layers):
+ self.assertListEqual(masked_weight[ix].get_shape().as_list(), [
+ base_depth + ix * depth_step, base_depth + (ix + 1) * depth_step
+ ])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py
new file mode 100644
index 0000000000..18ba3d1327
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py
@@ -0,0 +1,340 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module implementing RNN Cells with pruning.
+
+This module implements BasicLSTMCell and LSTMCell with pruning.
+Code adapted from third_party/tensorflow/python/ops/rnn_cell_impl.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.model_pruning.python.layers import core_layers
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn_cell as tf_rnn
+
+
+class MaskedBasicLSTMCell(tf_rnn.BasicLSTMCell):
+ """Basic LSTM recurrent network cell with pruning.
+
+ Overrides the call method of tensorflow BasicLSTMCell and injects the weight
+ masks
+
+ The implementation is based on: http://arxiv.org/abs/1409.2329.
+
+ We add forget_bias (default: 1) to the biases of the forget gate in order to
+ reduce the scale of forgetting in the beginning of the training.
+
+ It does not allow cell clipping, a projection layer, and does not
+ use peep-hole connections: it is the basic baseline.
+
+ For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
+ that follows.
+ """
+
+ def __init__(self,
+ num_units,
+ forget_bias=1.0,
+ state_is_tuple=True,
+ activation=None,
+ reuse=None,
+ name=None):
+ """Initialize the basic LSTM cell with pruning.
+
+ Args:
+ num_units: int, The number of units in the LSTM cell.
+ forget_bias: float, The bias added to forget gates (see above).
+ Must set to `0.0` manually when restoring from CudnnLSTM-trained
+ checkpoints.
+ state_is_tuple: If True, accepted and returned states are 2-tuples of
+ the `c_state` and `m_state`. If False, they are concatenated
+ along the column axis. The latter behavior will soon be deprecated.
+ activation: Activation function of the inner states. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ name: String, the name of the layer. Layers with the same name will
+ share weights, but to avoid mistakes we require reuse=True in such
+ cases.
+
+ When restoring from CudnnLSTM-trained checkpoints, must use
+ CudnnCompatibleLSTMCell instead.
+ """
+ super(MaskedBasicLSTMCell, self).__init__(
+ num_units,
+ forget_bias=forget_bias,
+ state_is_tuple=state_is_tuple,
+ activation=activation,
+ reuse=reuse,
+ name=name)
+
+ def build(self, inputs_shape):
+ # Call the build method of the parent class.
+ super(MaskedBasicLSTMCell, self).build(inputs_shape)
+
+ input_depth = inputs_shape[1].value
+ h_depth = self._num_units
+ self._mask = self.add_variable(
+ name="mask",
+ shape=[input_depth + h_depth, 4 * h_depth],
+ initializer=init_ops.ones_initializer(),
+ trainable=False,
+ dtype=self.dtype)
+ self._threshold = self.add_variable(
+ name="threshold",
+ shape=[],
+ initializer=init_ops.zeros_initializer(),
+ trainable=False,
+ dtype=self.dtype)
+ # Add masked_weights in the weights namescope so as to make it easier
+ # for the quantization library to add quant ops.
+ self._masked_kernel = math_ops.multiply(self._mask, self._kernel,
+ core_layers.MASKED_WEIGHT_NAME)
+ if self._mask not in ops.get_collection_ref(core_layers.MASK_COLLECTION):
+ ops.add_to_collection(core_layers.MASK_COLLECTION, self._mask)
+ ops.add_to_collection(core_layers.MASKED_WEIGHT_COLLECTION,
+ self._masked_kernel)
+ ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold)
+ ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel)
+
+ def call(self, inputs, state):
+ """Long short-term memory cell (LSTM) with masks for pruning.
+
+ Args:
+ inputs: `2-D` tensor with shape `[batch_size, input_size]`.
+ state: An `LSTMStateTuple` of state tensors, each shaped
+ `[batch_size, self.state_size]`, if `state_is_tuple` has been set to
+ `True`. Otherwise, a `Tensor` shaped
+ `[batch_size, 2 * self.state_size]`.
+
+ Returns:
+ A pair containing the new hidden state, and the new state (either a
+ `LSTMStateTuple` or a concatenated state, depending on
+ `state_is_tuple`).
+ """
+ sigmoid = math_ops.sigmoid
+ one = constant_op.constant(1, dtype=dtypes.int32)
+ # Parameters of gates are concatenated into one multiply for efficiency.
+ if self._state_is_tuple:
+ c, h = state
+ else:
+ c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)
+
+ gate_inputs = math_ops.matmul(
+ array_ops.concat([inputs, h], 1), self._masked_kernel)
+ gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
+
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ i, j, f, o = array_ops.split(
+ value=gate_inputs, num_or_size_splits=4, axis=one)
+
+ forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
+ # Note that using `add` and `multiply` instead of `+` and `*` gives a
+ # performance improvement. So using those at the cost of readability.
+ add = math_ops.add
+ multiply = math_ops.multiply
+ new_c = add(
+ multiply(c, sigmoid(add(f, forget_bias_tensor))),
+ multiply(sigmoid(i), self._activation(j)))
+ new_h = multiply(self._activation(new_c), sigmoid(o))
+
+ if self._state_is_tuple:
+ new_state = tf_rnn.LSTMStateTuple(new_c, new_h)
+ else:
+ new_state = array_ops.concat([new_c, new_h], 1)
+ return new_h, new_state
+
+
+class MaskedLSTMCell(tf_rnn.LSTMCell):
+ """LSTMCell with pruning.
+
+ Overrides the call method of tensorflow LSTMCell and injects the weight masks.
+ Masks are applied to only the weight matrix of the LSTM and not the
+ projection matrix.
+ """
+
+ def __init__(self,
+ num_units,
+ use_peepholes=False,
+ cell_clip=None,
+ initializer=None,
+ num_proj=None,
+ proj_clip=None,
+ num_unit_shards=None,
+ num_proj_shards=None,
+ forget_bias=1.0,
+ state_is_tuple=True,
+ activation=None,
+ reuse=None):
+ """Initialize the parameters for an LSTM cell with masks for pruning.
+
+ Args:
+ num_units: int, The number of units in the LSTM cell
+ use_peepholes: bool, set True to enable diagonal/peephole connections.
+ cell_clip: (optional) A float value, if provided the cell state is clipped
+ by this value prior to the cell output activation.
+ initializer: (optional) The initializer to use for the weight and
+ projection matrices.
+ num_proj: (optional) int, The output dimensionality for the projection
+ matrices. If None, no projection is performed.
+ proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
+ provided, then the projected values are clipped elementwise to within
+ `[-proj_clip, proj_clip]`.
+ num_unit_shards: Deprecated, will be removed by Jan. 2017.
+ Use a variable_scope partitioner instead.
+ num_proj_shards: Deprecated, will be removed by Jan. 2017.
+ Use a variable_scope partitioner instead.
+ forget_bias: Biases of the forget gate are initialized by default to 1
+ in order to reduce the scale of forgetting at the beginning of
+ the training. Must set it manually to `0.0` when restoring from
+ CudnnLSTM trained checkpoints.
+ state_is_tuple: If True, accepted and returned states are 2-tuples of
+ the `c_state` and `m_state`. If False, they are concatenated
+ along the column axis. This latter behavior will soon be deprecated.
+ activation: Activation function of the inner states. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+
+ When restoring from CudnnLSTM-trained checkpoints, must use
+ CudnnCompatibleLSTMCell instead.
+ """
+ super(MaskedLSTMCell, self).__init__(
+ num_units,
+ use_peepholes=use_peepholes,
+ cell_clip=cell_clip,
+ initializer=initializer,
+ num_proj=num_proj,
+ proj_clip=proj_clip,
+ num_unit_shards=num_unit_shards,
+ num_proj_shards=num_proj_shards,
+ forget_bias=forget_bias,
+ state_is_tuple=state_is_tuple,
+ activation=activation,
+ reuse=reuse)
+
+ def build(self, inputs_shape):
+ # Call the build method of the parent class.
+ super(MaskedLSTMCell, self).build(inputs_shape)
+
+ input_depth = inputs_shape[1].value
+ h_depth = self._num_units
+ self._mask = self.add_variable(
+ name="mask",
+ shape=[input_depth + h_depth, 4 * h_depth],
+ initializer=init_ops.ones_initializer(),
+ trainable=False,
+ dtype=self.dtype)
+ self._threshold = self.add_variable(
+ name="threshold",
+ shape=[],
+ initializer=init_ops.zeros_initializer(),
+ trainable=False,
+ dtype=self.dtype)
+ # Add masked_weights in the weights namescope so as to make it easier
+ # for the quantization library to add quant ops.
+ self._masked_kernel = math_ops.multiply(self._mask, self._kernel,
+ core_layers.MASKED_WEIGHT_NAME)
+ if self._mask not in ops.get_collection_ref(core_layers.MASK_COLLECTION):
+ ops.add_to_collection(core_layers.MASK_COLLECTION, self._mask)
+ ops.add_to_collection(core_layers.MASKED_WEIGHT_COLLECTION,
+ self._masked_kernel)
+ ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold)
+ ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel)
+
+ def call(self, inputs, state):
+ """Run one step of LSTM.
+
+ Args:
+ inputs: input Tensor, 2D, `[batch, num_units].
+ state: if `state_is_tuple` is False, this must be a state Tensor,
+ `2-D, [batch, state_size]`. If `state_is_tuple` is True, this must be a
+ tuple of state Tensors, both `2-D`, with column sizes `c_state` and
+ `m_state`.
+
+ Returns:
+ A tuple containing:
+
+ - A `2-D, [batch, output_dim]`, Tensor representing the output of the
+ LSTM after reading `inputs` when previous state was `state`.
+ Here output_dim is:
+ num_proj if num_proj was set,
+ num_units otherwise.
+ - Tensor(s) representing the new state of LSTM after reading `inputs` when
+ the previous state was `state`. Same type and shape(s) as `state`.
+
+ Raises:
+ ValueError: If input size cannot be inferred from inputs via
+ static shape inference.
+ """
+ num_proj = self._num_units if self._num_proj is None else self._num_proj
+ sigmoid = math_ops.sigmoid
+
+ if self._state_is_tuple:
+ (c_prev, m_prev) = state
+ else:
+ c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
+ m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
+
+ input_size = inputs.get_shape().with_rank(2)[1]
+ if input_size.value is None:
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ lstm_matrix = math_ops.matmul(
+ array_ops.concat([inputs, m_prev], 1), self._masked_kernel)
+ lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias)
+
+ i, j, f, o = array_ops.split(
+ value=lstm_matrix, num_or_size_splits=4, axis=1)
+ # Diagonal connections
+ if self._use_peepholes:
+ c = (
+ sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
+ sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
+ else:
+ c = (
+ sigmoid(f + self._forget_bias) * c_prev +
+ sigmoid(i) * self._activation(j))
+
+ if self._cell_clip is not None:
+ # pylint: disable=invalid-unary-operand-type
+ c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
+ # pylint: enable=invalid-unary-operand-type
+ if self._use_peepholes:
+ m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
+ else:
+ m = sigmoid(o) * self._activation(c)
+
+ if self._num_proj is not None:
+ m = math_ops.matmul(m, self._proj_kernel)
+
+ if self._proj_clip is not None:
+ # pylint: disable=invalid-unary-operand-type
+ m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
+ # pylint: enable=invalid-unary-operand-type
+
+ new_state = (
+ tf_rnn.LSTMStateTuple(c, m)
+ if self._state_is_tuple else array_ops.concat([c, m], 1))
+ return m, new_state
diff --git a/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py b/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py
new file mode 100644
index 0000000000..e85ae7b22a
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py
@@ -0,0 +1,85 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for creating different number of masks in rnn_cells."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.model_pruning.python import pruning
+from tensorflow.contrib.model_pruning.python.layers import rnn_cells
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import rnn_cell as tf_rnn_cells
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class RnnCellsTest(test.TestCase):
+
+ def setUp(self):
+ super(RnnCellsTest, self).setUp()
+ self.batch_size = 8
+ self.dim = 10
+
+ def testMaskedBasicLSTMCell(self):
+ expected_num_masks = 1
+ expected_num_rows = 2 * self.dim
+ expected_num_cols = 4 * self.dim
+ with self.test_session():
+ inputs = variables.Variable(
+ random_ops.random_normal([self.batch_size, self.dim]))
+ c = variables.Variable(
+ random_ops.random_normal([self.batch_size, self.dim]))
+ h = variables.Variable(
+ random_ops.random_normal([self.batch_size, self.dim]))
+ state = tf_rnn_cells.LSTMStateTuple(c, h)
+ lstm_cell = rnn_cells.MaskedBasicLSTMCell(self.dim)
+ lstm_cell(inputs, state)
+ self.assertEqual(len(pruning.get_masks()), expected_num_masks)
+ self.assertEqual(len(pruning.get_masked_weights()), expected_num_masks)
+ self.assertEqual(len(pruning.get_thresholds()), expected_num_masks)
+ self.assertEqual(len(pruning.get_weights()), expected_num_masks)
+
+ for mask in pruning.get_masks():
+ self.assertEqual(mask.shape, (expected_num_rows, expected_num_cols))
+ for weight in pruning.get_weights():
+ self.assertEqual(weight.shape, (expected_num_rows, expected_num_cols))
+
+ def testMaskedLSTMCell(self):
+ expected_num_masks = 1
+ expected_num_rows = 2 * self.dim
+ expected_num_cols = 4 * self.dim
+ with self.test_session():
+ inputs = variables.Variable(
+ random_ops.random_normal([self.batch_size, self.dim]))
+ c = variables.Variable(
+ random_ops.random_normal([self.batch_size, self.dim]))
+ h = variables.Variable(
+ random_ops.random_normal([self.batch_size, self.dim]))
+ state = tf_rnn_cells.LSTMStateTuple(c, h)
+ lstm_cell = rnn_cells.MaskedLSTMCell(self.dim)
+ lstm_cell(inputs, state)
+ self.assertEqual(len(pruning.get_masks()), expected_num_masks)
+ self.assertEqual(len(pruning.get_masked_weights()), expected_num_masks)
+ self.assertEqual(len(pruning.get_thresholds()), expected_num_masks)
+ self.assertEqual(len(pruning.get_weights()), expected_num_masks)
+
+ for mask in pruning.get_masks():
+ self.assertEqual(mask.shape, (expected_num_rows, expected_num_cols))
+ for weight in pruning.get_weights():
+ self.assertEqual(weight.shape, (expected_num_rows, expected_num_cols))
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/model_pruning/python/learning.py b/tensorflow/contrib/model_pruning/python/learning.py
new file mode 100644
index 0000000000..2b79c23cef
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/python/learning.py
@@ -0,0 +1,188 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Wrapper around tf-slim's training code contrib/slim/python/slim/learning.py
+to support training of pruned models
+
+*******************************************************************
+* A simple working training script with support for model pruning *
+*******************************************************************
+
+ # Load data and create the model:
+ images, labels = LoadData(...)
+ predictions = MyModel(images)
+
+ # Define the loss:
+ slim.losses.log_loss(predictions, labels)
+ total_loss = slim.losses.get_total_loss()
+
+ # Define the optimizer:
+ optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum)
+
+ # Create the train_op
+ train_op = slim.learning.create_train_op(total_loss, optimizer)
+
+ # Set up sparsity
+ sparsity = pruning.setup_gradual_sparsity(self.global_step)
+
+ # Create mask update op
+ mask_update_op = pruning.add_mask_update_ip(sparsity)
+
+ # Run training.
+ learning.train(train_op,
+ my_log_dir,
+ mask_update_op)
+ see contrib/slim/python/slim/learning.py for additional examples
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import slim as _slim
+
+_USE_DEFAULT = 0
+train_step = _slim.learning.train_step
+
+
+def train(train_op,
+ logdir,
+ mask_update_op,
+ train_step_fn=train_step,
+ train_step_kwargs=_USE_DEFAULT,
+ log_every_n_steps=1,
+ graph=None,
+ master='',
+ is_chief=True,
+ global_step=None,
+ number_of_steps=None,
+ init_op=_USE_DEFAULT,
+ init_feed_dict=None,
+ local_init_op=_USE_DEFAULT,
+ init_fn=None,
+ ready_op=_USE_DEFAULT,
+ summary_op=_USE_DEFAULT,
+ save_summaries_secs=600,
+ summary_writer=_USE_DEFAULT,
+ startup_delay_steps=0,
+ saver=None,
+ save_interval_secs=600,
+ sync_optimizer=None,
+ session_config=None,
+ trace_every_n_steps=None):
+ """Wrapper around tf-slim's train function.
+
+ Runs a training loop using a TensorFlow supervisor.
+ When the sync_optimizer is supplied, gradient updates are applied
+ synchronously. Otherwise, gradient updates are applied asynchronous.
+
+ Args:
+ train_op: A `Tensor` that, when executed, will apply the gradients and
+ return the loss value.
+ logdir: The directory where training logs are written to. If None, model
+ checkpoints and summaries will not be written.
+ mask_update_op: Operation that upon execution updates the weight masks and
+ thresholds.
+ train_step_fn: The function to call in order to execute a single gradient
+ step. The function must have take exactly four arguments: the current
+ session, the `train_op` `Tensor`, a global step `Tensor` and a dictionary.
+ train_step_kwargs: A dictionary which is passed to the `train_step_fn`. By
+ default, two `Boolean`, scalar ops called "should_stop" and "should_log"
+ are provided.
+ log_every_n_steps: The frequency, in terms of global steps, that the loss
+ and global step and logged.
+ graph: The graph to pass to the supervisor. If no graph is supplied the
+ default graph is used.
+ master: The address of the tensorflow master.
+ is_chief: Specifies whether or not the training is being run by the primary
+ replica during replica training.
+ global_step: The `Tensor` representing the global step. If left as `None`,
+ then slim.variables.get_or_create_global_step() is used.
+ number_of_steps: The max number of gradient steps to take during training,
+ as measured by 'global_step': training will stop if global_step is
+ greater than 'number_of_steps'. If the value is left as None, training
+ proceeds indefinitely.
+ init_op: The initialization operation. If left to its default value, then
+ the session is initialized by calling `tf.global_variables_initializer()`.
+ init_feed_dict: A feed dictionary to use when executing the `init_op`.
+ local_init_op: The local initialization operation. If left to its default
+ value, then the session is initialized by calling
+ `tf.local_variables_initializer()` and `tf.tables_initializer()`.
+ init_fn: An optional callable to be executed after `init_op` is called. The
+ callable must accept one argument, the session being initialized.
+ ready_op: Operation to check if the model is ready to use. If left to its
+ default value, then the session checks for readiness by calling
+ `tf.report_uninitialized_variables()`.
+ summary_op: The summary operation.
+ save_summaries_secs: How often, in seconds, to save summaries.
+ summary_writer: `SummaryWriter` to use. Can be `None`
+ to indicate that no summaries should be written. If unset, we
+ create a SummaryWriter.
+ startup_delay_steps: The number of steps to wait for before beginning. Note
+ that this must be 0 if a sync_optimizer is supplied.
+ saver: Saver to save checkpoints. If None, a default one will be created
+ and used.
+ save_interval_secs: How often, in seconds, to save the model to `logdir`.
+ sync_optimizer: an instance of tf.train.SyncReplicasOptimizer, or a list of
+ them. If the argument is supplied, gradient updates will be synchronous.
+ If left as `None`, gradient updates will be asynchronous.
+ session_config: An instance of `tf.ConfigProto` that will be used to
+ configure the `Session`. If left as `None`, the default will be used.
+ trace_every_n_steps: produce and save a `Timeline` in Chrome trace format
+ and add it to the summaries every `trace_every_n_steps`. If None, no trace
+ information will be produced or saved.
+
+ Returns:
+ the value of the loss function after training.
+
+ Raises:
+ ValueError: if `train_op` is empty or if `startup_delay_steps` is
+ non-zero when `sync_optimizer` is supplied, if `number_of_steps` is
+ negative, or if `trace_every_n_steps` is not `None` and no `logdir` is
+ provided.
+ """
+
+ def train_step_with_pruning_fn(sess, train_op, global_step,
+ train_step_kwargs):
+ total_loss, should_stop = train_step_fn(sess, train_op, global_step,
+ train_step_kwargs)
+ sess.run(mask_update_op)
+ return total_loss, should_stop
+
+ total_loss, _ = _slim.learning.train(
+ train_op,
+ logdir,
+ train_step_fn=train_step_with_pruning_fn,
+ train_step_kwargs=train_step_kwargs,
+ log_every_n_steps=log_every_n_steps,
+ graph=graph,
+ master=master,
+ is_chief=is_chief,
+ global_step=global_step,
+ number_of_steps=number_of_steps,
+ init_op=init_op,
+ init_feed_dict=init_feed_dict,
+ local_init_op=local_init_op,
+ init_fn=init_fn,
+ ready_op=ready_op,
+ summary_op=summary_op,
+ save_summaries_secs=save_summaries_secs,
+ summary_writer=summary_writer,
+ startup_delay_steps=startup_delay_steps,
+ saver=saver,
+ save_interval_secs=save_interval_secs,
+ sync_optimizer=sync_optimizer,
+ session_config=session_config,
+ trace_every_n_steps=trace_every_n_steps)
+
+ return total_loss
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
new file mode 100644
index 0000000000..42d91a71fd
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -0,0 +1,585 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helper functions to add support for magnitude-based model pruning.
+
+ # Adds variables and ops to the graph to enable
+ # elementwise masking of weights
+ apply_mask(weights)
+
+ # Returns a list containing the sparsity of each of the weight tensors
+ get_weight_sparsity()
+
+ # Returns a list of all the masked weight tensorflow variables
+ get_masked_weights()
+
+ # Returns a list of all the mask tensorflow variables
+ get_masks()
+
+ # Returns a list of all the thresholds
+ get_thresholds()
+
+ # Returns a list of all the weight tensors that have been masked
+ get_weights()
+
+ The Pruning class uses a proto (defined in pruning.proto) to set up the
+ parameters for a pruning specification. Here's a typical usage:
+
+ # Initialize a pruning spec from a proto
+ pruning_spec = '/tmp/pruning.pb'
+ p = Pruning(pruning_spec)
+
+ # Add mask update ops to the graph
+ mask_update_op = p.conditional_mask_update_op()
+
+ # Add the summaries
+ p.add_pruning_summaries()
+
+ # Run the op
+ session.run(mask_update_op)
+
+ # An object of the pruning also accepts externally defined sparsity:
+ sparsity = tf.Variable(0.5, name = "ConstantSparsity")
+ pruning_spec = '/tmp/pruning.pb'
+ p = Pruning(pruning_spec, sparsity=sparsity)
+
+"""
+# pylint: disable=missing-docstring
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.model_pruning.python.layers import core_layers as core
+from tensorflow.contrib.training.python.training import hparam
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_impl
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.summary import summary
+from tensorflow.python.training import training_util
+
+_MASK_COLLECTION = core.MASK_COLLECTION
+_THRESHOLD_COLLECTION = core.THRESHOLD_COLLECTION
+_MASKED_WEIGHT_COLLECTION = core.MASKED_WEIGHT_COLLECTION
+_WEIGHT_COLLECTION = core.WEIGHT_COLLECTION
+_MASKED_WEIGHT_NAME = core.MASKED_WEIGHT_NAME
+
+
+def _weight_mask_variable(var, scope):
+ """Create a mask for the weights.
+
+ This function adds a variable 'mask' to the graph.
+
+ Args:
+ var: the weight variable that needs to be masked
+ scope: The variable scope of the variable var
+
+ Returns:
+ the mask variable of the same size and shape as var, initialized to all 1s.
+ """
+ with variable_scope.variable_scope(scope):
+ mask = variable_scope.get_variable(
+ 'mask',
+ var.get_shape(),
+ initializer=init_ops.ones_initializer(),
+ trainable=False,
+ dtype=var.dtype)
+ return mask
+
+
+def _weight_threshold_variable(var, scope):
+ """Create a scalar threshold for the weights.
+
+ This function adds a variable
+ 'threshold' to the graph.
+
+ Args:
+ var: The weight variable that needs to be masked
+ scope: The variable scope of the variable var
+
+ Returns:
+ a scalar threshold variable initialized to 0.
+ """
+ with variable_scope.variable_scope(scope):
+ threshold = variable_scope.get_variable(
+ 'threshold', [],
+ initializer=init_ops.zeros_initializer(),
+ trainable=False,
+ dtype=var.dtype)
+ return threshold
+
+
+def _histogram(values, value_range, nbins=100, dtype=np.int32, name=None):
+ """Return histogram of values.
+
+ Given the tensor `values`, this operation returns a rank 1 histogram counting
+ the number of entries in `values` that fell into every bin. The bins are
+ equal width and determined by the arguments `value_range` and `nbins`.
+
+ Args:
+ values: Numeric `Tensor`.
+ value_range: Shape [2] `Tensor` of same `dtype` as `values`.
+ values <= value_range[0] will be mapped to hist[0],
+ values >= value_range[1] will be mapped to hist[-1].
+ nbins: Scalar `int32 Tensor`. Number of histogram bins.
+ dtype: dtype for returned histogram.
+ name: A name for this operation (defaults to 'histogram').
+
+ Returns:
+ A 1-D `Tensor` holding histogram of values.
+
+ """
+ with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope:
+ values = ops.convert_to_tensor(values, name='values')
+ values = gen_array_ops.reshape(values, [-1])
+ value_range = ops.convert_to_tensor(value_range, name='value_range')
+ nbins = ops.convert_to_tensor(nbins, dtype=np.int32, name='nbins')
+ nbins_float = math_ops.cast(nbins, values.dtype)
+
+ # Map tensor values that fall within value_range to [0, 1].
+ scaled_values = math_ops.truediv(
+ values - value_range[0],
+ value_range[1] - value_range[0],
+ name='scaled_values')
+
+ # map tensor values within the open interval value_range to {0,.., nbins-1},
+ # values outside the open interval will be zero or less, or nbins or more.
+ indices = math_ops.floor(nbins_float * scaled_values, name='indices')
+
+ # Clip edge cases (e.g. value = value_range[1]) or "outliers."
+ indices = math_ops.cast(
+ clip_ops.clip_by_value(indices, 0, nbins_float - 1), np.int32)
+
+ return math_ops.unsorted_segment_sum(
+ array_ops.ones_like(indices, dtype=dtype), indices, nbins, name=scope)
+
+
+def _determine_partitioned_axis(partitioned_variable):
+ partitioned_axis = 0
+ concatenated_variable_shape = partitioned_variable.get_shape()
+ for partition in partitioned_variable:
+ partition_shape = partition.get_shape()
+ maybe_partitioned_axis = np.less(partition_shape,
+ concatenated_variable_shape)
+ # Sanity check: make sure number of partitioned axis == 1
+ if np.count_nonzero(maybe_partitioned_axis) != 1:
+ raise ValueError('Number of partitioned axes %s not equal to 1' %
+ np.count_nonzero(maybe_partitioned_axis))
+ partitioned_axis = np.where(maybe_partitioned_axis)[0][0]
+ return partitioned_axis
+
+
+def _variable_assign(var, new_value):
+ return state_ops.assign(var, new_value, name=var.op.name + '_assign')
+
+
+def _partitioned_variable_assign(partitioned_var, new_value):
+ """Assign op for partitioned variables.
+
+ Args:
+ partitioned_var: A partitioned tensotflow variable
+ new_value: Value to be assigned to the variable var
+
+ Returns:
+ A tensorflow op that groups the assign ops for each of the variable slices
+ """
+ # Determine which axis was used to partition the variable. Currently
+ # tensorflow allows partitioning variable only along 1 axis.
+ axis = 0 if len(partitioned_var) == 1 else _determine_partitioned_axis(
+ partitioned_var)
+
+ partition_sizes = np.array(
+ [partition.get_shape()[axis] for partition in partitioned_var])
+ new_partitioned_values = array_ops.split(
+ new_value,
+ ops.convert_to_tensor(partition_sizes, dtype=np.int32),
+ axis=axis)
+ op_list = []
+ for partition in partitioned_var:
+ op_list.append(
+ _variable_assign(partition, new_partitioned_values[len(op_list)]))
+ return control_flow_ops.group(
+ *op_list, name=partitioned_var.name + '_group_assign')
+
+
+def apply_mask(x, scope=''):
+ """Apply mask to a given weight tensor.
+
+ Args:
+ x: Input weight tensor
+ scope: The current variable scope. Defaults to ""
+ Returns:
+ Tensor representing masked_weights
+ """
+
+ mask = _weight_mask_variable(x, scope)
+ threshold = _weight_threshold_variable(x, scope)
+ # Add masked_weights in the weights namescope so as to make it easier
+ # for the quantization library to add quant ops.
+ masked_weights = math_ops.multiply(mask, x, _MASKED_WEIGHT_NAME)
+
+ # Make sure the mask for a given variable are not added multiple times to the
+ # collection. This is particularly important when applying mask to RNN's
+ # weight variables
+ if mask not in ops.get_collection_ref(_MASK_COLLECTION):
+ ops.add_to_collection(_THRESHOLD_COLLECTION, threshold)
+ ops.add_to_collection(_MASK_COLLECTION, mask)
+ ops.add_to_collection(_MASKED_WEIGHT_COLLECTION, masked_weights)
+ ops.add_to_collection(_WEIGHT_COLLECTION, x)
+ return masked_weights
+
+
+def get_masked_weights():
+ return ops.get_collection(_MASKED_WEIGHT_COLLECTION)
+
+
+def get_masks():
+ return ops.get_collection(_MASK_COLLECTION)
+
+
+def get_thresholds():
+ return ops.get_collection(_THRESHOLD_COLLECTION)
+
+
+def get_weights():
+ return ops.get_collection(_WEIGHT_COLLECTION)
+
+
+def get_weight_sparsity():
+ """Get sparsity of the weights.
+
+ Args:
+ None
+
+ Returns:
+ A list containing the sparsity of each of the weight tensors
+ """
+ masks = get_masks()
+ return [nn_impl.zero_fraction(mask) for mask in masks]
+
+
+def get_pruning_hparams():
+ """Get a tf.HParams object with the default values for the hyperparameters.
+
+ name: string
+ name of the pruning specification. Used for adding summaries and ops under
+ a common tensorflow name_scope
+ begin_pruning_step: integer
+ the global step at which to begin pruning
+ end_pruning_step: integer
+ the global step at which to terminate pruning. Defaults to -1 implying
+ that pruning continues till the training stops
+ do_not_prune: list of strings
+ list of layers that are not pruned
+ threshold_decay: float
+ the decay factor to use for exponential decay of the thresholds
+ pruning_frequency: integer
+ How often should the masks be updated? (in # of global_steps)
+ nbins: integer
+ number of bins to use for histogram computation
+ initial_sparsity: float
+ initial sparsity value
+ target_sparsity: float
+ target sparsity value
+ sparsity_function_begin_step: integer
+ the global step at this which the gradual sparsity function begins to
+ take effect
+ sparsity_function_end_step: integer
+ the global step used as the end point for the gradual sparsity function
+ sparsity_function_exponent: float
+ exponent = 1 is linearly varying sparsity between initial and final.
+ exponent > 1 varies more slowly towards the end than the beginning
+
+ We use the following sparsity function:
+
+ num_steps = (sparsity_function_end_step -
+ sparsity_function_begin_step)/pruning_frequency
+ sparsity(step) = (initial_sparsity - target_sparsity)*
+ [1-step/(num_steps -1)]**exponent + target_sparsity
+
+ Args:
+ None
+
+ Returns:
+ tf.HParams object initialized to default values
+
+ """
+ return hparam.HParams(
+ name='model_pruning',
+ begin_pruning_step=0,
+ end_pruning_step=-1,
+ do_not_prune=[''],
+ threshold_decay=0.9,
+ pruning_frequency=10,
+ nbins=255,
+ initial_sparsity=0,
+ target_sparsity=0.5,
+ sparsity_function_begin_step=0,
+ sparsity_function_end_step=100,
+ sparsity_function_exponent=3)
+
+
+class Pruning(object):
+
+ def __init__(self,
+ spec=None,
+ global_step=None,
+ sparsity=None,
+ partitioner=None):
+ """Set up the specification for model pruning.
+
+ If a spec is provided, the sparsity is set up based on the sparsity_function
+ in the spec. The effect of sparsity_function is overridden if the sparsity
+ variable is passed to the constructor. This enables setting up arbitrary
+ sparsity profiles externally and passing it to this pruning functions.
+
+ Args:
+ spec: Pruning spec as defined in pruning.proto
+ global_step: A tensorflow variable that is used while setting up the
+ sparsity function
+ sparsity: A tensorflow scalar variable storing the sparsity
+ partitioner: The tensorflow partitioner function used to distribute
+ parameters across shards
+ """
+ # Pruning specification
+ self._spec = spec if spec else get_pruning_hparams()
+
+ # A tensorflow variable that tracks the sparsity function.
+ # If not provided as input, the graph must already contain the global_step
+ # variable before calling this constructor.
+ self._global_step = self._setup_global_step(global_step)
+
+ # Stores the tensorflow sparsity variable.
+ # Built using self._setup_sparsity() or provided externally
+ self._sparsity = sparsity if sparsity else self._setup_sparsity()
+
+ # Stores the partitioner function uses to partition variables across tasks/
+ self._partitioner = partitioner
+
+ # List of tensorflow assignments ops for new masks and thresholds
+ self._assign_ops = []
+
+ # Tensorflow variable keeping track of the last global step when the masks
+ # were updated
+ self._last_update_step = self._setup_last_update_step()
+
+ def _setup_global_step(self, global_step):
+ graph_global_step = global_step
+ if graph_global_step is None:
+ graph_global_step = training_util.get_global_step()
+
+ return math_ops.cast(graph_global_step, np.int32)
+
+ def _setup_sparsity(self):
+ begin_step = self._spec.sparsity_function_begin_step
+ end_step = self._spec.sparsity_function_end_step
+ initial_sparsity = self._spec.initial_sparsity
+ target_sparsity = self._spec.target_sparsity
+ exponent = self._spec.sparsity_function_exponent
+
+ if begin_step >= end_step:
+ raise ValueError(
+ 'Pruning must begin before it can end. begin_step=%d, end_step=%d' %
+ (begin_step, end_step))
+
+ with ops.name_scope(self._spec.name):
+ p = math_ops.minimum(1.0,
+ math_ops.maximum(
+ 0.0,
+ math_ops.div(
+ math_ops.cast(self._global_step - begin_step,
+ np.float32),
+ end_step - begin_step)))
+ sparsity = math_ops.add(
+ math_ops.multiply(initial_sparsity - target_sparsity,
+ math_ops.pow(1 - p, exponent)),
+ target_sparsity,
+ name='sparsity')
+
+ return sparsity
+
+ def _setup_last_update_step(self):
+ with variable_scope.variable_scope(self._spec.name) as scope:
+ try:
+ last_update_step = variable_scope.get_variable(
+ 'last_mask_update_step', [],
+ initializer=init_ops.zeros_initializer(),
+ trainable=False,
+ dtype=np.int32)
+ except ValueError:
+ scope.reuse_variables()
+ last_update_step = variable_scope.get_variable(
+ 'last_mask_update_step', dtype=np.int32)
+ return last_update_step
+
+ def _exists_in_do_not_prune_list(self, tensor_name):
+ do_not_prune_list = self._spec.do_not_prune
+ if not do_not_prune_list[0]:
+ return False
+ for layer_name in do_not_prune_list:
+ if tensor_name.find(layer_name) != -1:
+ return True
+
+ return False
+
+ def _update_mask(self, weights, threshold):
+ """Updates the mask for a given weight tensor.
+
+ This functions first computes the cdf of the weight tensor, and estimates
+ the threshold value such that 'desired_sparsity' fraction of weights
+ have magnitude less than the threshold.
+
+ Args:
+ weights: The weight tensor that needs to be masked.
+ threshold: The current threshold value. The function will compute a new
+ threshold and return the exponential moving average using the current
+ value of threshold
+
+ Returns:
+ new_threshold: The new value of the threshold based on weights, and
+ desired_sparsity
+ new_mask: A n-D numpy array containing 0 or 1 to indicate which of the
+ values in weights falls below the threshold
+
+ Raises:
+ ValueError: if sparsity is not defined
+ """
+ if self._sparsity is None:
+ raise ValueError('Sparsity variable undefined')
+
+ with ops.name_scope(weights.op.name + '_pruning_ops'):
+ abs_weights = math_ops.abs(weights)
+ max_value = math_ops.reduce_max(abs_weights)
+ histogram = _histogram(
+ abs_weights, [0.0, max_value],
+ nbins=self._spec.nbins,
+ dtype=np.float32)
+
+ cdf = math_ops.cumsum(histogram)
+ norm_cdf = math_ops.div(cdf, math_ops.reduce_sum(histogram))
+ current_threshold = math_ops.multiply(
+ math_ops.div(
+ math_ops.reduce_sum(
+ math_ops.cast(
+ math_ops.less(norm_cdf, self._sparsity), np.float32)),
+ float(self._spec.nbins)), max_value)
+
+ smoothed_threshold = math_ops.add_n([
+ math_ops.multiply(current_threshold, 1 - self._spec.threshold_decay),
+ math_ops.multiply(threshold, self._spec.threshold_decay)
+ ])
+ new_mask = math_ops.cast(
+ math_ops.greater(abs_weights, smoothed_threshold), np.float32)
+ return smoothed_threshold, new_mask
+
+ def _get_mask_assign_ops(self):
+ # Make sure the assignment ops have not already been added to the list
+ if self._assign_ops:
+ raise ValueError(
+ 'Assign op list not empty. _get_mask_assign_ops() called twice?')
+
+ masks = get_masks()
+ weights = get_weights()
+ thresholds = get_thresholds()
+
+ if len(masks) != len(thresholds):
+ raise ValueError(
+ 'Number of masks %s and number of thresholds %s mismatch' %
+ (len(masks), len(thresholds)))
+
+ for index, mask in enumerate(masks):
+ threshold = thresholds[index]
+ weight = weights[index] if self._partitioner is None else weights[
+ index].as_tensor()
+
+ if self._spec.do_not_prune:
+ if self._exists_in_do_not_prune_list(mask.name):
+ continue
+
+ new_threshold, new_mask = self._update_mask(weight, threshold)
+ self._assign_ops.append(_variable_assign(threshold, new_threshold))
+ self._assign_ops.append(
+ _variable_assign(mask, new_mask) if self._partitioner is None else
+ _partitioned_variable_assign(mask, new_mask))
+
+ def mask_update_op(self):
+ with ops.name_scope(self._spec.name):
+ if not self._assign_ops:
+ self._get_mask_assign_ops()
+ with ops.control_dependencies([
+ state_ops.assign(
+ self._last_update_step,
+ self._global_step,
+ name='last_mask_update_step_assign')
+ ]):
+ with ops.control_dependencies(self._assign_ops):
+ logging.info('Updating masks.')
+ return control_flow_ops.no_op('mask_update')
+
+ def conditional_mask_update_op(self):
+
+ def maybe_update_masks():
+ with ops.name_scope(self._spec.name):
+ is_step_within_pruning_range = math_ops.logical_and(
+ math_ops.greater_equal(self._global_step,
+ self._spec.begin_pruning_step),
+ # If end_pruning_step is negative, keep pruning forever!
+ math_ops.logical_or(
+ math_ops.less_equal(self._global_step,
+ self._spec.end_pruning_step),
+ math_ops.less(self._spec.end_pruning_step, 0)))
+ is_pruning_step = math_ops.less_equal(
+ math_ops.add(self._last_update_step, self._spec.pruning_frequency),
+ self._global_step)
+ return math_ops.logical_and(is_step_within_pruning_range,
+ is_pruning_step)
+
+ def mask_update_op():
+ return self.mask_update_op()
+
+ def no_update_op():
+ return control_flow_ops.no_op()
+
+ return control_flow_ops.cond(maybe_update_masks(), mask_update_op,
+ no_update_op)
+
+ def add_pruning_summaries(self):
+ """Adds summaries for this pruning spec.
+
+ Args: none
+
+ Returns: none
+ """
+ with ops.name_scope(self._spec.name + '_summaries'):
+ summary.scalar('sparsity', self._sparsity)
+ summary.scalar('last_mask_update_step', self._last_update_step)
+ masks = get_masks()
+ thresholds = get_thresholds()
+ for index, mask in enumerate(masks):
+ if not self._exists_in_do_not_prune_list(mask.name):
+ summary.scalar(mask.name + '/sparsity', nn_impl.zero_fraction(mask))
+ summary.scalar(thresholds[index].op.name + '/threshold',
+ thresholds[index])
+
+ def print_hparams(self):
+ logging.info(self._spec.to_json())
diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py
new file mode 100644
index 0000000000..c23fd649ce
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/python/pruning_test.py
@@ -0,0 +1,162 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the key functions in pruning library."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.model_pruning.python import pruning
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import training_util
+
+
+class PruningHParamsTest(test.TestCase):
+ PARAM_LIST = [
+ "name=test", "threshold_decay=0.9", "pruning_frequency=10",
+ "do_not_prune=[conv1,conv2]", "sparsity_function_end_step=100",
+ "target_sparsity=0.9"
+ ]
+ TEST_HPARAMS = ",".join(PARAM_LIST)
+
+ def setUp(self):
+ super(PruningHParamsTest, self).setUp()
+ # Add global step variable to the graph
+ self.global_step = training_util.get_or_create_global_step()
+ # Add sparsity
+ self.sparsity = variables.Variable(0.5, name="sparsity")
+ # Parse hparams
+ self.pruning_hparams = pruning.get_pruning_hparams().parse(
+ self.TEST_HPARAMS)
+
+ def testInit(self):
+ p = pruning.Pruning(self.pruning_hparams)
+ self.assertEqual(p._spec.name, "test")
+ self.assertAlmostEqual(p._spec.threshold_decay, 0.9)
+ self.assertEqual(p._spec.pruning_frequency, 10)
+ self.assertAllEqual(p._spec.do_not_prune, ["conv1", "conv2"])
+ self.assertEqual(p._spec.sparsity_function_end_step, 100)
+ self.assertAlmostEqual(p._spec.target_sparsity, 0.9)
+
+ def testInitWithExternalSparsity(self):
+ with self.test_session():
+ p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
+ variables.global_variables_initializer().run()
+ sparsity = p._sparsity.eval()
+ self.assertAlmostEqual(sparsity, 0.5)
+
+ def testInitWithVariableReuse(self):
+ with self.test_session():
+ p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
+ p_copy = pruning.Pruning(
+ spec=self.pruning_hparams, sparsity=self.sparsity)
+ variables.global_variables_initializer().run()
+ sparsity = p._sparsity.eval()
+ self.assertAlmostEqual(sparsity, 0.5)
+ self.assertEqual(p._sparsity.eval(), p_copy._sparsity.eval())
+
+
+class PruningTest(test.TestCase):
+
+ def setUp(self):
+ super(PruningTest, self).setUp()
+ self.global_step = training_util.get_or_create_global_step()
+
+ def testCreateMask2D(self):
+ width = 10
+ height = 20
+ with self.test_session():
+ weights = variables.Variable(
+ random_ops.random_normal([width, height], stddev=1), name="weights")
+ masked_weights = pruning.apply_mask(weights,
+ variable_scope.get_variable_scope())
+ variables.global_variables_initializer().run()
+ weights_val = weights.eval()
+ masked_weights_val = masked_weights.eval()
+ self.assertAllEqual(weights_val, masked_weights_val)
+
+ def testUpdateSingleMask(self):
+ with self.test_session() as session:
+ weights = variables.Variable(
+ math_ops.linspace(1.0, 100.0, 100), name="weights")
+ masked_weights = pruning.apply_mask(weights)
+ sparsity = variables.Variable(0.5, name="sparsity")
+ p = pruning.Pruning(sparsity=sparsity)
+ p._spec.threshold_decay = 0.0
+ mask_update_op = p.mask_update_op()
+ variables.global_variables_initializer().run()
+ masked_weights_val = masked_weights.eval()
+ self.assertAllEqual(np.count_nonzero(masked_weights_val), 100)
+ session.run(mask_update_op)
+ masked_weights_val = masked_weights.eval()
+ self.assertAllEqual(np.count_nonzero(masked_weights_val), 51)
+
+ def testPartitionedVariableMasking(self):
+ partitioner = partitioned_variables.variable_axis_size_partitioner(40)
+ with self.test_session() as session:
+ with variable_scope.variable_scope("", partitioner=partitioner):
+ sparsity = variables.Variable(0.5, name="Sparsity")
+ weights = variable_scope.get_variable(
+ "weights", initializer=math_ops.linspace(1.0, 100.0, 100))
+ masked_weights = pruning.apply_mask(
+ weights, scope=variable_scope.get_variable_scope())
+ p = pruning.Pruning(sparsity=sparsity, partitioner=partitioner)
+ p._spec.threshold_decay = 0.0
+ mask_update_op = p.mask_update_op()
+ variables.global_variables_initializer().run()
+ masked_weights_val = masked_weights.eval()
+ session.run(mask_update_op)
+ masked_weights_val = masked_weights.eval()
+ self.assertAllEqual(np.count_nonzero(masked_weights_val), 51)
+
+ def testConditionalMaskUpdate(self):
+ param_list = [
+ "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6"
+ ]
+ test_spec = ",".join(param_list)
+ pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
+ weights = variables.Variable(
+ math_ops.linspace(1.0, 100.0, 100), name="weights")
+ masked_weights = pruning.apply_mask(weights)
+ sparsity = variables.Variable(0.00, name="sparsity")
+ # Set up pruning
+ p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
+ p._spec.threshold_decay = 0.0
+ mask_update_op = p.conditional_mask_update_op()
+ sparsity_val = math_ops.linspace(0.0, 0.9, 10)
+ increment_global_step = state_ops.assign_add(self.global_step, 1)
+ non_zero_count = []
+ with self.test_session() as session:
+ variables.global_variables_initializer().run()
+ for i in range(10):
+ session.run(state_ops.assign(sparsity, sparsity_val[i]))
+ session.run(mask_update_op)
+ session.run(increment_global_step)
+ non_zero_count.append(np.count_nonzero(masked_weights.eval()))
+ # Weights pruned at steps 0,2,4,and,6
+ expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
+ self.assertAllEqual(expected_non_zero_count, non_zero_count)
+
+
+if __name__ == "__main__":
+ test.main()