aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning/python/pruning_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-27 10:27:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-27 10:34:01 -0800
commite20be23387a6c1b72f3e34d03d4206c3211c921a (patch)
tree11e3d90ab8ddad8da4d1e502fedf1aae9223b288 /tensorflow/contrib/model_pruning/python/pruning_test.py
parente929b16dc89f62a41bcaba57b98ddd221bf9bf68 (diff)
Make block-based pruning more general, allowing it to operate on higher-dimensional arrays that can be squeezed to 2-dimensional.
PiperOrigin-RevId: 187195105
Diffstat (limited to 'tensorflow/contrib/model_pruning/python/pruning_test.py')
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_test.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py
index 1767b4bb94..89e6571319 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_test.py
@@ -142,6 +142,23 @@ class PruningTest(test.TestCase):
self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max,
expected_mask)
+ self._blockMasking(param_list + ["block_pooling_function=AVG"], weights_avg,
+ expected_mask)
+
+ def testBlockMaskingWithHigherDimensions(self):
+ param_list = ["block_height=2", "block_width=2", "threshold_decay=0"]
+
+ # Weights as in testBlockMasking, but with one extra dimension.
+ weights_avg = constant_op.constant(
+ [[[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4],
+ [0.3, 0.3, 0.4, 0.4]]])
+ weights_max = constant_op.constant(
+ [[[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2], [0.3, 0.0, 0.4, 0.0],
+ [0.0, -0.3, 0.0, -0.4]]])
+ expected_mask = [[[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]]
+
+ self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max,
+ expected_mask)
self._blockMasking(param_list + ["block_pooling_function=AVG"],
weights_avg, expected_mask)