aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning
diff options
context:
space:
mode:
authorGravatar Jesse <jessehagenaars@gmail.com>2018-06-06 16:07:20 +0200
committerGravatar GitHub <noreply@github.com>2018-06-06 16:07:20 +0200
commit90b28b7316edb644b71b01edaaa8553d5913fc19 (patch)
tree842d452d3b80b908dd9c169928305213f60d251d /tensorflow/contrib/model_pruning
parente106a458dd26db58c7d5abbd4afef60f8ce33252 (diff)
Removed redundant use of enumeration
Since every mask has an accompanying threshold, zip(masks, thresholds) can be used instead of enumerate(masks) and calling thresholds by index.
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index e6f9acc139..d843fa26d5 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -518,11 +518,10 @@ class Pruning(object):
summary.scalar('last_mask_update_step', self._last_update_step)
masks = get_masks()
thresholds = get_thresholds()
- for index, mask in enumerate(masks):
+ for mask, threshold in zip(masks, thresholds):
if not self._exists_in_do_not_prune_list(mask.name):
summary.scalar(mask.op.name + '/sparsity', nn_impl.zero_fraction(mask))
- summary.scalar(thresholds[index].op.name + '/threshold',
- thresholds[index])
+ summary.scalar(threshold.op.name + '/threshold', threshold)
def print_hparams(self):
logging.info(self._spec.to_json())