author    Suyog Gupta <suyoggupta@google.com>    2018-08-08 10:00:54 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-08-08 10:04:58 -0700
commit    152c5563ca2474ec9394442086e53b26b41e0773 (patch)
tree      6fe7e4dac9b33111ea5097b10cb0898944ba5172 /tensorflow/contrib/model_pruning
parent    c3d102c47a8c4cacdb7a4e055224ac2aabf2b578 (diff)
Add utility function to pruning library to strip a trained graph of pruning-related variables
PiperOrigin-RevId: 207902316
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r-- tensorflow/contrib/model_pruning/BUILD | 42
-rw-r--r-- tensorflow/contrib/model_pruning/README.md | 46
-rw-r--r-- tensorflow/contrib/model_pruning/__init__.py | 6
-rw-r--r-- tensorflow/contrib/model_pruning/python/pruning_test.py | 1
-rw-r--r-- tensorflow/contrib/model_pruning/python/strip_pruning_vars.py | 103
-rw-r--r-- tensorflow/contrib/model_pruning/python/strip_pruning_vars_lib.py | 142
-rw-r--r-- tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py | 232
7 files changed, 562 insertions, 10 deletions
diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD
index 54bd39afac..16ddc38f5a 100644
--- a/tensorflow/contrib/model_pruning/BUILD
+++ b/tensorflow/contrib/model_pruning/BUILD
@@ -95,6 +95,22 @@ py_library(
],
)
+py_library(
+ name = "strip_pruning_vars_lib",
+ srcs = ["python/strip_pruning_vars_lib.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":pruning",
+ "//tensorflow/python:client",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:training",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
py_test(
name = "pruning_utils_test",
size = "small",
@@ -129,6 +145,31 @@ py_test(
],
)
+py_test(
+ name = "strip_pruning_vars_test",
+ size = "small",
+ srcs = ["python/strip_pruning_vars_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":layers",
+ ":pruning",
+ ":rnn_cells",
+ ":strip_pruning_vars_lib",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_binary(
+ name = "strip_pruning_vars",
+ srcs = ["python/strip_pruning_vars.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":strip_pruning_vars_lib",
+ "//tensorflow/python:platform",
+ ],
+)
+
py_library(
name = "init_py",
srcs = ["__init__.py"],
@@ -145,5 +186,6 @@ py_library(
":learning",
":pruning",
":rnn_cells",
+ ":strip_pruning_vars_lib",
],
)
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index dbe4e124fd..0761dea900 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -4,7 +4,15 @@ This document describes the API that facilitates magnitude-based pruning of
a neural network's weight tensors. The API helps inject the necessary TensorFlow ops
into the training graph so the model can be pruned while it is being trained.
-### Model creation
+## Table of contents
+1. [Model creation](#model-creation)
+2. [Hyperparameters for pruning](#hyperparameters)
+ - [Block sparsity](#block-sparsity)
+3. [Adding pruning ops to the training graph](#adding-pruning-ops)
+4. [Removing pruning ops from the trained graph](#remove)
+5. [Example](#example)
+
+### Model creation <a name="model-creation"></a>
The first step involves adding mask and threshold variables to the layers that
need to undergo pruning. The variable mask is the same shape as the layer's
@@ -33,7 +41,7 @@ auxiliary variables built-in (see
* [rnn_cells.MaskedLSTMCell](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py?l=154)
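A minimal sketch of this first step, using the masked layer wrappers listed above (shapes and scope names here are illustrative):

```python
import tensorflow as tf
from tensorflow.contrib.model_pruning.python.layers import layers

# Each masked layer creates auxiliary `mask` and `threshold` variables next to
# its weights and computes with the element-wise product mask * weights.
inputs = tf.ones([8, 32, 32, 3])
net = layers.masked_conv2d(inputs, num_outputs=16, kernel_size=3, scope='conv1')
net = tf.contrib.layers.flatten(net)
logits = layers.masked_fully_connected(net, 10)
```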
-### Adding pruning ops to the training graph
+### Pruning-related hyperparameters <a name="hyperparameters"></a>
The pruning library allows for specification of the following hyperparameters:
@@ -64,7 +72,13 @@ is divided into $$n$$ intervals of size equal to the pruning_frequency ($$\Delta
t$$). $$s_f$$ is the target_sparsity, $$s_i$$ is the initial_sparsity, $$t_0$$
is the sparsity_function_begin_step. In this equation, the
sparsity_function_exponent is set to 3.
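In full, the gradual sparsity schedule that these symbols parameterize (with the sparsity_function_exponent at its default value of 3) is:

$$s_t = s_f + (s_i - s_f)\left(1 - \frac{t - t_0}{n\Delta t}\right)^3 \quad \text{for } t \in \{t_0, t_0 + \Delta t, \ldots, t_0 + n\Delta t\}$$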
-### Adding pruning ops to the training graph
+
+#### Block Sparsity <a name="block-sparsity"></a>
+
+For some hardware architectures, it may be beneficial to induce spatially correlated sparsity. To train models in which the weight tensors have a block-sparse structure, set the *block_height* and *block_width* hyperparameters to the desired block configuration (2x2, 4x4, 4x1, 1x8, etc.). Currently, block sparsity is only supported for weight tensors that can be squeezed to rank 2. The matrix is partitioned into non-overlapping blocks of size *[block_height, block_width]*, and either the average or the max absolute value in each block is taken as a proxy for the entire block (as set by the *block_pooling_function* hyperparameter).
+Convolutional layer tensors are always pruned using block dimensions of [1,1].
+
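As a sketch, the block configuration is passed through the same hyperparameter string as the other settings; the values below are illustrative:

```python
from tensorflow.contrib.model_pruning.python import pruning

# Hypothetical spec: prune in 2x2 blocks, pooling each block by its max
# absolute value.
hparams = pruning.get_pruning_hparams().parse(
    'name=block_sparse_pruning,block_height=2,block_width=2,'
    'block_pooling_function=MAX,target_sparsity=0.9')
```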
+### Adding pruning ops to the training graph <a name="adding-pruning-ops"></a>
The final step involves adding ops to the training graph that monitor the
distribution of the layer's weight magnitudes and determine the layer threshold,
@@ -105,7 +119,19 @@ with tf.graph.as_default():
```
Ensure that `global_step` is being [incremented](https://www.tensorflow.org/api_docs/python/tf/train/Optimizer#minimize), otherwise pruning will not work!
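A minimal sketch of a train op that advances `global_step` (the loss below is an illustrative stand-in for your model's loss):

```python
import tensorflow as tf

# Illustrative stand-in for the model's training loss.
weights = tf.Variable(tf.random_normal([10, 10]))
loss = tf.reduce_sum(tf.square(weights))

global_step = tf.train.get_or_create_global_step()
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
# minimize() increments global_step once per step; the pruning schedule reads
# this counter to decide when the masks are updated.
train_op = optimizer.minimize(loss, global_step=global_step)
```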
-## Example: Pruning and training deep CNNs on the cifar10 dataset
+### Removing pruning ops from the trained graph <a name="remove"></a>
+Once the model is trained, it is necessary to remove the auxiliary variables (mask, threshold) and pruning ops added to the graph in the steps above. This can be accomplished using the `strip_pruning_vars` utility.
+
+This utility generates a binary GraphDef in which the variables have been converted to constants. In particular, the threshold variables are removed from the graph, and each mask variable is fused with its corresponding weight tensor to produce a `masked_weight` tensor. This tensor is sparse and has the same size as the weight tensor, with sparsity set by the `target_sparsity` or `weight_sparsity_map` hyperparameters above.
+
+```shell
+$ bazel build -c opt contrib/model_pruning:strip_pruning_vars
+$ bazel-bin/contrib/model_pruning/strip_pruning_vars --checkpoint_dir=/path/to/checkpoints/ --output_node_names=graph_node1,graph_node2 --output_dir=/tmp --filename=pruning_stripped.pb
+```
+
+For now, it is assumed that the underlying hardware platform will provide mechanisms for compressing the sparse tensors and/or accelerating the sparse tensor computations.
+
+## Example: Pruning and training deep CNNs on the cifar10 dataset <a name="example"></a>
Please see https://www.tensorflow.org/tutorials/deep_cnn for details on neural
network architecture, setting up inputs etc. The additional changes needed to
@@ -121,7 +147,7 @@ incorporate pruning are captured in the following:
To train the pruned version of cifar10:
-```bash
+```shell
$ examples_dir=contrib/model_pruning/examples
$ bazel build -c opt $examples_dir/cifar10:cifar10_{train,eval}
$ bazel-bin/$examples_dir/cifar10/cifar10_train --pruning_hparams=name=cifar10_pruning,begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000
@@ -133,10 +159,14 @@ Eval:
$ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once
```
-### Block Sparsity
+Removing pruning nodes from the trained graph:
-For some hardware architectures, it may be beneficial to induce spatially correlated sparsity. To train models in which the weight tensors have block sparse structure, set *block_height* and *block_width* hyperparameters to the desired block configuration (2x2, 4x4, 4x1, 1x8, etc). Currently, block sparsity is only supported for weight tensors which can be squeezed to rank 2. The matrix is partitioned into non-overlapping blocks of size *[block_height, block_dim]* and the either the average or max absolute value in this block is taken as a proxy for the entire block (set by *block_pooling_function* hyperparameter).
-The convolution layer tensors are always pruned used block dimensions of [1,1].
+```shell
+$ bazel build -c opt contrib/model_pruning:strip_pruning_vars
+$ bazel-bin/contrib/model_pruning/strip_pruning_vars --checkpoint_dir=/tmp/cifar10_train --output_node_names=softmax_linear/softmax_linear_2 --filename=cifar_pruned.pb
+```
+
+The generated GraphDef (cifar_pruned.pb) may be visualized using the [`import_pb_to_tensorboard`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/tools/import_pb_to_tensorboard.py) utility.
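The same can be done programmatically; a sketch assuming the `import_to_tensorboard` helper in that module (verify the signature for your TensorFlow version):

```python
from tensorflow.python.tools import import_pb_to_tensorboard

# Writes an event file for the frozen GraphDef so TensorBoard can render it.
import_pb_to_tensorboard.import_to_tensorboard(
    model_dir='/tmp/cifar_pruned.pb', log_dir='/tmp/pruned_graph_logs')
```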
## References
diff --git a/tensorflow/contrib/model_pruning/__init__.py b/tensorflow/contrib/model_pruning/__init__.py
index d32bedbcd6..6eca54aaee 100644
--- a/tensorflow/contrib/model_pruning/__init__.py
+++ b/tensorflow/contrib/model_pruning/__init__.py
@@ -33,6 +33,9 @@ from tensorflow.contrib.model_pruning.python.pruning import get_thresholds
from tensorflow.contrib.model_pruning.python.pruning import get_weight_sparsity
from tensorflow.contrib.model_pruning.python.pruning import get_weights
from tensorflow.contrib.model_pruning.python.pruning import Pruning
+from tensorflow.contrib.model_pruning.python.strip_pruning_vars_lib import graph_def_from_checkpoint
+from tensorflow.contrib.model_pruning.python.strip_pruning_vars_lib import strip_pruning_vars_fn
+
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
@@ -41,7 +44,8 @@ _allowed_symbols = [
'masked_convolution', 'masked_conv2d', 'masked_fully_connected',
'MaskedBasicLSTMCell', 'MaskedLSTMCell', 'train', 'apply_mask',
'get_masked_weights', 'get_masks', 'get_pruning_hparams', 'get_thresholds',
- 'get_weights', 'get_weight_sparsity', 'Pruning'
+ 'get_weights', 'get_weight_sparsity', 'Pruning', 'strip_pruning_vars_fn',
+ 'graph_def_from_checkpoint'
]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py
index 5b67656e9f..33c4ad58bd 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_test.py
@@ -60,7 +60,6 @@ class PruningHParamsTest(test.TestCase):
self.assertEqual(p._weight_sparsity_map["conv1"], 0.8)
self.assertEqual(p._weight_sparsity_map["conv2/kernel"], 0.8)
-
def testInitWithExternalSparsity(self):
with self.test_session():
p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
diff --git a/tensorflow/contrib/model_pruning/python/strip_pruning_vars.py b/tensorflow/contrib/model_pruning/python/strip_pruning_vars.py
new file mode 100644
index 0000000000..3385103807
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/python/strip_pruning_vars.py
@@ -0,0 +1,103 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Removes the auxiliary variables and ops added by the pruning library.
+
+Usage:
+
+bazel build tensorflow/contrib/model_pruning:strip_pruning_vars && \
+bazel-bin/tensorflow/contrib/model_pruning/strip_pruning_vars \
+--checkpoint_dir=/tmp/model_ckpts \
+--output_node_names=softmax \
+--output_dir=/tmp \
+--filename=pruning_stripped.pb
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+
+from tensorflow.contrib.model_pruning.python import strip_pruning_vars_lib
+from tensorflow.python.framework import graph_io
+from tensorflow.python.platform import app
+from tensorflow.python.platform import tf_logging as logging
+
+FLAGS = None
+
+
+def strip_pruning_vars(checkpoint_dir, output_node_names, output_dir, filename):
+ """Remove pruning-related auxiliary variables and ops from the graph.
+
+ Accepts training checkpoints and produces a GraphDef in which the pruning vars
+ and ops have been removed.
+
+ Args:
+ checkpoint_dir: Path to the checkpoints.
+ output_node_names: The name of the output nodes, comma separated.
+ output_dir: Directory where to write the graph.
+ filename: Output GraphDef file name.
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: if output_node_names is not provided.
+ """
+ if not output_node_names:
+ raise ValueError(
+ 'Need to specify at least 1 output node through output_node_names flag')
+ output_node_names = output_node_names.replace(' ', '').split(',')
+
+ initial_graph_def = strip_pruning_vars_lib.graph_def_from_checkpoint(
+ checkpoint_dir, output_node_names)
+
+ final_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn(
+ initial_graph_def, output_node_names)
+ graph_io.write_graph(final_graph_def, output_dir, filename, as_text=False)
+ logging.info('\nFinal graph written to %s', os.path.join(
+ output_dir, filename))
+
+
+def main(unused_args):
+ return strip_pruning_vars(FLAGS.checkpoint_dir, FLAGS.output_node_names,
+ FLAGS.output_dir, FLAGS.filename)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--checkpoint_dir', type=str, default='', help='Path to the checkpoints.')
+ parser.add_argument(
+ '--output_node_names',
+ type=str,
+ default='',
+ help='The name of the output nodes, comma separated.')
+ parser.add_argument(
+ '--output_dir',
+ type=str,
+ default='/tmp',
+ help='Directory where to write the graph.')
+ parser.add_argument(
+ '--filename',
+ type=str,
+ default='pruning_stripped.pb',
+ help='Output \'GraphDef\' file name.')
+
+ FLAGS, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/model_pruning/python/strip_pruning_vars_lib.py b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_lib.py
new file mode 100644
index 0000000000..fc4b10863f
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_lib.py
@@ -0,0 +1,142 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities to remove pruning-related ops and variables from a GraphDef.
+"""
+
+# pylint: disable=missing-docstring
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.framework import node_def_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import graph_util
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import saver as saver_lib
+
+
+def _node_name(tensor_name):
+ """Remove the trailing ':0' from the variable name."""
+ if ':' not in tensor_name:
+ return tensor_name
+
+ return tensor_name.split(':')[0]
+
+
+def _tensor_name(node_name):
+ """Appends the :0 in the op name to get the canonical tensor name."""
+ if ':' in node_name:
+ return node_name
+
+ return node_name + ':0'
+
+
+def _get_masked_weights(input_graph_def):
+ """Extracts masked_weights from the graph as a dict of {var_name:ndarray}."""
+ input_graph = ops.Graph()
+ with input_graph.as_default():
+ importer.import_graph_def(input_graph_def, name='')
+
+ with session.Session(graph=input_graph) as sess:
+ masked_weights_dict = {}
+ for node in input_graph_def.node:
+ if 'masked_weight' in node.name:
+ masked_weight_val = sess.run(
+ sess.graph.get_tensor_by_name(_tensor_name(node.name)))
+ logging.info(
+ '%s has %d values, %1.2f%% zeros \n', node.name,
+ np.size(masked_weight_val),
+ 100 - float(100 * np.count_nonzero(masked_weight_val)) /
+ np.size(masked_weight_val))
+ masked_weights_dict.update({node.name: masked_weight_val})
+ return masked_weights_dict
+
+
+def strip_pruning_vars_fn(input_graph_def, output_node_names):
+ """Removes mask variable from the graph.
+
+ Replaces the masked_weight tensor with element-wise multiplication of mask
+ and the corresponding weight variable.
+
+ Args:
+ input_graph_def: A GraphDef in which the variables have been converted to
+ constants. This is typically the output of
+ tf.graph_util.convert_variables_to_constants().
+ output_node_names: List of name strings for the result nodes of the graph
+
+ Returns:
+ A GraphDef in which pruning-related variables have been removed
+ """
+ masked_weights_dict = _get_masked_weights(input_graph_def)
+ pruned_graph_def = graph_pb2.GraphDef()
+
+ # Replace masked_weight with a const op containing the
+ # result of tf.multiply(mask, weight)
+ for node in input_graph_def.node:
+ output_node = node_def_pb2.NodeDef()
+ if 'masked_weight' in node.name:
+ output_node.op = 'Const'
+ output_node.name = node.name
+ dtype = node.attr['T']
+ data = masked_weights_dict[node.name]
+ output_node.attr['dtype'].CopyFrom(dtype)
+ output_node.attr['value'].CopyFrom(
+ attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(data)))
+
+ else:
+ output_node.CopyFrom(node)
+ pruned_graph_def.node.extend([output_node])
+
+ # Remove stranded nodes: mask and weights
+ return graph_util.extract_sub_graph(pruned_graph_def, output_node_names)
+
+
+def graph_def_from_checkpoint(checkpoint_dir, output_node_names):
+ """Converts checkpoint data to GraphDef.
+
+ Reads the latest checkpoint data and produces a GraphDef in which the
+ variables have been converted to constants.
+
+ Args:
+ checkpoint_dir: Path to the checkpoints.
+ output_node_names: List of name strings for the result nodes of the graph.
+
+ Returns:
+ A GraphDef from the latest checkpoint
+
+ Raises:
+ ValueError: if no checkpoint is found
+ """
+ checkpoint_path = saver_lib.latest_checkpoint(checkpoint_dir)
+ if checkpoint_path is None:
+ raise ValueError('Could not find a checkpoint at: {0}.'
+ .format(checkpoint_dir))
+
+ saver_for_restore = saver_lib.import_meta_graph(
+ checkpoint_path + '.meta', clear_devices=True)
+ with session.Session() as sess:
+ saver_for_restore.restore(sess, checkpoint_path)
+ graph_def = ops.get_default_graph().as_graph_def()
+ output_graph_def = graph_util.convert_variables_to_constants(
+ sess, graph_def, output_node_names)
+
+ return output_graph_def
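For reference, the two functions in this file compose as follows; a minimal sketch with illustrative paths and node names:

```python
from tensorflow.contrib.model_pruning.python import strip_pruning_vars_lib
from tensorflow.python.framework import graph_io

output_nodes = ['softmax']  # illustrative output node name

# Freeze the latest checkpoint into a constant-only GraphDef...
graph_def = strip_pruning_vars_lib.graph_def_from_checkpoint(
    '/tmp/model_ckpts', output_nodes)
# ...then fuse each mask into its weight and drop the pruning variables.
stripped = strip_pruning_vars_lib.strip_pruning_vars_fn(graph_def, output_nodes)
graph_io.write_graph(stripped, '/tmp', 'pruning_stripped.pb', as_text=False)
```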
diff --git a/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py
new file mode 100644
index 0000000000..255daa0360
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py
@@ -0,0 +1,232 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for strip_pruning_vars."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from tensorflow.contrib.model_pruning.python import pruning
+from tensorflow.contrib.model_pruning.python import strip_pruning_vars_lib
+from tensorflow.contrib.model_pruning.python.layers import layers
+from tensorflow.contrib.model_pruning.python.layers import rnn_cells
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import graph_util
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell as tf_rnn_cells
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import training_util
+
+
+def _get_number_pruning_vars(graph_def):
+ number_vars = 0
+ for node in graph_def.node:
+ if re.match(r"^.*(mask$)|(threshold$)", node.name):
+ number_vars += 1
+ return number_vars
+
+
+def _get_node_names(tensor_names):
+ return [
+ strip_pruning_vars_lib._node_name(tensor_name)
+ for tensor_name in tensor_names
+ ]
+
+
+class StripPruningVarsTest(test.TestCase):
+
+ def setUp(self):
+ param_list = [
+ "pruning_frequency=1", "begin_pruning_step=1", "end_pruning_step=10",
+ "nbins=2048", "threshold_decay=0.0"
+ ]
+ self.initial_graph = ops.Graph()
+ self.initial_graph_def = None
+ self.final_graph = ops.Graph()
+ self.final_graph_def = None
+ self.pruning_spec = ",".join(param_list)
+ with self.initial_graph.as_default():
+ self.sparsity = variables.Variable(0.5, name="sparsity")
+ self.global_step = training_util.get_or_create_global_step()
+ self.increment_global_step = state_ops.assign_add(self.global_step, 1)
+ self.mask_update_op = None
+
+ def _build_convolutional_model(self, number_of_layers):
+ # Create a graph with several conv2d layers
+ kernel_size = 3
+ base_depth = 4
+ depth_step = 7
+ height, width = 7, 9
+ with variable_scope.variable_scope("conv_model"):
+ input_tensor = array_ops.ones((8, height, width, base_depth))
+ top_layer = input_tensor
+ for ix in range(number_of_layers):
+ top_layer = layers.masked_conv2d(
+ top_layer,
+ base_depth + (ix + 1) * depth_step,
+ kernel_size,
+ scope="Conv_" + str(ix))
+
+ return top_layer
+
+ def _build_fully_connected_model(self, number_of_layers):
+ base_depth = 4
+ depth_step = 7
+
+ input_tensor = array_ops.ones((8, base_depth))
+
+ top_layer = input_tensor
+
+ with variable_scope.variable_scope("fc_model"):
+ for ix in range(number_of_layers):
+ top_layer = layers.masked_fully_connected(
+ top_layer, base_depth + (ix + 1) * depth_step)
+
+ return top_layer
+
+ def _build_lstm_model(self, number_of_layers):
+ batch_size = 8
+ dim = 10
+ inputs = variables.Variable(random_ops.random_normal([batch_size, dim]))
+
+ def lstm_cell():
+ return rnn_cells.MaskedBasicLSTMCell(
+ dim, forget_bias=0.0, state_is_tuple=True, reuse=False)
+
+ cell = tf_rnn_cells.MultiRNNCell(
+ [lstm_cell() for _ in range(number_of_layers)], state_is_tuple=True)
+
+ outputs = rnn.static_rnn(
+ cell, [inputs],
+ initial_state=cell.zero_state(batch_size, dtypes.float32))
+
+ return outputs
+
+ def _prune_model(self, session):
+ pruning_hparams = pruning.get_pruning_hparams().parse(self.pruning_spec)
+ p = pruning.Pruning(pruning_hparams, sparsity=self.sparsity)
+ self.mask_update_op = p.conditional_mask_update_op()
+
+ variables.global_variables_initializer().run()
+ for _ in range(20):
+ session.run(self.mask_update_op)
+ session.run(self.increment_global_step)
+
+ def _get_outputs(self, session, input_graph, tensors_list, graph_prefix=None):
+ outputs = []
+
+ for output_tensor in tensors_list:
+ if graph_prefix:
+ output_tensor = graph_prefix + "/" + output_tensor
+ outputs.append(
+ session.run(session.graph.get_tensor_by_name(output_tensor)))
+
+ return outputs
+
+ def _get_initial_outputs(self, output_tensor_names_list):
+ with self.test_session(graph=self.initial_graph) as sess1:
+ self._prune_model(sess1)
+ reference_outputs = self._get_outputs(sess1, self.initial_graph,
+ output_tensor_names_list)
+
+ self.initial_graph_def = graph_util.convert_variables_to_constants(
+ sess1, sess1.graph.as_graph_def(),
+ _get_node_names(output_tensor_names_list))
+ return reference_outputs
+
+ def _get_final_outputs(self, output_tensor_names_list):
+ self.final_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn(
+ self.initial_graph_def, _get_node_names(output_tensor_names_list))
+ _ = importer.import_graph_def(self.final_graph_def, name="final")
+
+ with self.test_session(self.final_graph) as sess2:
+ final_outputs = self._get_outputs(
+ sess2,
+ self.final_graph,
+ output_tensor_names_list,
+ graph_prefix="final")
+ return final_outputs
+
+ def _check_removal_of_pruning_vars(self, number_masked_layers):
+ self.assertEqual(
+ _get_number_pruning_vars(self.initial_graph_def), number_masked_layers)
+ self.assertEqual(_get_number_pruning_vars(self.final_graph_def), 0)
+
+ def _check_output_equivalence(self, initial_outputs, final_outputs):
+ for initial_output, final_output in zip(initial_outputs, final_outputs):
+ self.assertAllEqual(initial_output, final_output)
+
+ def testConvolutionalModel(self):
+ with self.initial_graph.as_default():
+ number_masked_conv_layers = 5
+ top_layer = self._build_convolutional_model(number_masked_conv_layers)
+ output_tensor_names = [top_layer.name]
+ initial_outputs = self._get_initial_outputs(output_tensor_names)
+
+ # Remove pruning-related nodes.
+ with self.final_graph.as_default():
+ final_outputs = self._get_final_outputs(output_tensor_names)
+
+ # Check that the final graph has no pruning-related vars
+ self._check_removal_of_pruning_vars(number_masked_conv_layers)
+
+ # Check that outputs remain the same after removal of pruning-related nodes
+ self._check_output_equivalence(initial_outputs, final_outputs)
+
+ def testFullyConnectedModel(self):
+ with self.initial_graph.as_default():
+ number_masked_fc_layers = 3
+ top_layer = self._build_fully_connected_model(number_masked_fc_layers)
+ output_tensor_names = [top_layer.name]
+ initial_outputs = self._get_initial_outputs(output_tensor_names)
+
+ # Remove pruning-related nodes.
+ with self.final_graph.as_default():
+ final_outputs = self._get_final_outputs(output_tensor_names)
+
+ # Check that the final graph has no pruning-related vars
+ self._check_removal_of_pruning_vars(number_masked_fc_layers)
+
+ # Check that outputs remain the same after removal of pruning-related nodes
+ self._check_output_equivalence(initial_outputs, final_outputs)
+
+ def testLSTMModel(self):
+ with self.initial_graph.as_default():
+ number_masked_lstm_layers = 2
+ outputs = self._build_lstm_model(number_masked_lstm_layers)
+ output_tensor_names = [outputs[0][0].name]
+ initial_outputs = self._get_initial_outputs(output_tensor_names)
+
+ # Remove pruning-related nodes.
+ with self.final_graph.as_default():
+ final_outputs = self._get_final_outputs(output_tensor_names)
+
+ # Check that the final graph has no pruning-related vars
+ self._check_removal_of_pruning_vars(number_masked_lstm_layers)
+
+ # Check that outputs remain the same after removal of pruning-related nodes
+ self._check_output_equivalence(initial_outputs, final_outputs)
+
+
+if __name__ == "__main__":
+ test.main()