aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-02 06:59:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-02 07:07:49 -0800
commitc499f4211d1a963d1d123f91632f48a056de7eaa (patch)
treed598b6426a92ce292358cf64e9ca3b458d622bab
parent41b616baaef5271917b016a0493a9f87217607c4 (diff)
Added convenience function to filter a list of variables based on regular expressions.
Change: 140845713
-rw-r--r--tensorflow/contrib/framework/python/ops/variables.py63
-rw-r--r--tensorflow/contrib/framework/python/ops/variables_test.py103
2 files changed, 163 insertions, 3 deletions
diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py
index 2db91cd889..c6425e8d24 100644
--- a/tensorflow/contrib/framework/python/ops/variables.py
+++ b/tensorflow/contrib/framework/python/ops/variables.py
@@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import re
+
from tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope
from tensorflow.contrib.framework.python.ops import gen_variable_ops
from tensorflow.contrib.util import loader
@@ -46,6 +48,7 @@ __all__ = ['add_model_variable',
'assign_from_values',
'assign_from_values_fn',
'create_global_step',
+ 'filter_variables',
'get_global_step',
'get_or_create_global_step',
'get_local_variables',
@@ -624,3 +627,63 @@ class VariableDeviceChooser(object):
device_spec.job = self._job_name
device_spec.task = task_id
return device_spec.to_string()
+
+
+def filter_variables(var_list, include_patterns=None, exclude_patterns=None,
+ reg_search=True):
+ """Filter a list of variables using regular expressions.
+
+ First includes variables according to the list of include_patterns.
+ Afterwards, eliminates variables according to the list of exclude_patterns.
+
+ For example, one can obtain a list of variables with the weights of all
+ convolutional layers (depending on the network definition) by:
+
+ ```python
+ variables = tf.contrib.framework.get_model_variables()
+ conv_weight_variables = tf.contrib.framework.filter_variables(
+ variables,
+ include_patterns=['Conv'],
+ exclude_patterns=['biases', 'Logits'])
+ ```
+
+ Args:
+ var_list: list of variables.
+ include_patterns: list of regular expressions to include. Defaults to None,
+ which means all variables are selected according to the include rules.
+ A variable is included if it matches any of the include_patterns.
+ exclude_patterns: list of regular expressions to exclude. Defaults to None,
+ which means all variables are selected according to the exclude rules.
+ A variable is excluded if it matches any of the exclude_patterns.
+ reg_search: boolean. If True (default), performs re.search to find matches
+ (i.e. pattern can match any substring of the variable name). If False,
+ performs re.match (i.e. regexp should match from the beginning of the
+ variable name).
+
+ Returns:
+ filtered list of variables.
+ """
+ if reg_search:
+ reg_exp_func = re.search
+ else:
+ reg_exp_func = re.match
+
+ # First include variables.
+ if include_patterns is None:
+ included_variables = list(var_list)
+ else:
+ included_variables = []
+ for var in var_list:
+ if any(reg_exp_func(ptrn, var.name) for ptrn in include_patterns):
+ included_variables.append(var)
+
+ # Afterwards, exclude variables.
+ if exclude_patterns is None:
+ filtered_variables = included_variables
+ else:
+ filtered_variables = []
+ for var in included_variables:
+ if not any(reg_exp_func(ptrn, var.name) for ptrn in exclude_patterns):
+ filtered_variables.append(var)
+
+ return filtered_variables
diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py
index 191291cbbf..d0264678da 100644
--- a/tensorflow/contrib/framework/python/ops/variables_test.py
+++ b/tensorflow/contrib/framework/python/ops/variables_test.py
@@ -1094,17 +1094,18 @@ class AssignFromCheckpointFnTest(tf.test.TestCase):
self.assertEqual(init_value0, var0.eval())
self.assertEqual(init_value1, var1.eval())
+
class ZeroInitializerOpTest(tf.test.TestCase):
def _testZeroInitializer(self, shape, initializer, use_init):
var = tf.Variable(initializer)
var_zero = tf.contrib.framework.zero_initializer(var)
with self.test_session() as sess:
- with self.assertRaisesOpError("Attempting to use uninitialized value"):
+ with self.assertRaisesOpError('Attempting to use uninitialized value'):
var.eval()
if use_init:
sess.run(var.initializer)
- with self.assertRaisesOpError("input is already initialized"):
+ with self.assertRaisesOpError('input is already initialized'):
var_zero.eval()
self.assertAllClose(np.ones(shape), var.eval())
else:
@@ -1115,7 +1116,103 @@ class ZeroInitializerOpTest(tf.test.TestCase):
for dtype in (tf.int32, tf.int64, tf.float32, tf.float64):
for use_init in (False, True):
self._testZeroInitializer(
- [10, 20], tf.ones([10, 20], dtype = dtype), use_init)
+ [10, 20], tf.ones([10, 20], dtype=dtype), use_init)
+
+
+class FilterVariablesTest(tf.test.TestCase):
+
+ def setUp(self):
+ g = tf.Graph()
+ with g.as_default():
+ var_list = []
+ var_list.append(tf.Variable(0, name='conv1/weights'))
+ var_list.append(tf.Variable(0, name='conv1/biases'))
+ var_list.append(tf.Variable(0, name='conv2/weights'))
+ var_list.append(tf.Variable(0, name='conv2/biases'))
+ var_list.append(tf.Variable(0, name='clfs/weights'))
+ var_list.append(tf.Variable(0, name='clfs/biases'))
+ self._var_list = var_list
+
+ def _test_filter_variables(self, expected_var_names, include_patterns=None,
+ exclude_patterns=None, reg_search=True):
+ filtered_var_list = tf.contrib.framework.filter_variables(
+ self._var_list,
+ include_patterns=include_patterns,
+ exclude_patterns=exclude_patterns,
+ reg_search=reg_search)
+
+ filtered_var_names = [var.op.name for var in filtered_var_list]
+
+ for name in filtered_var_names:
+ self.assertIn(name, expected_var_names)
+ for name in expected_var_names:
+ self.assertIn(name, filtered_var_names)
+ self.assertEqual(len(filtered_var_names), len(expected_var_names))
+
+ def testNoFiltering(self):
+ self._test_filter_variables(
+ expected_var_names=[
+ 'conv1/weights',
+ 'conv1/biases',
+ 'conv2/weights',
+ 'conv2/biases',
+ 'clfs/weights',
+ 'clfs/biases'])
+
+ def testIncludeBiases(self):
+ self._test_filter_variables(
+ expected_var_names=[
+ 'conv1/biases',
+ 'conv2/biases',
+ 'clfs/biases'],
+ include_patterns=['biases'])
+
+ def testExcludeWeights(self):
+ self._test_filter_variables(
+ expected_var_names=[
+ 'conv1/biases',
+ 'conv2/biases',
+ 'clfs/biases'],
+ exclude_patterns=['weights'])
+
+ def testExcludeWeightsAndConv1(self):
+ self._test_filter_variables(
+ expected_var_names=[
+ 'conv2/biases',
+ 'clfs/biases'],
+ exclude_patterns=['weights', 'conv1'])
+
+ def testTwoIncludePatternsEnsureNoVariablesTwiceInFilteredList(self):
+ self._test_filter_variables(
+ expected_var_names=[
+ 'conv1/weights',
+ 'conv1/biases',
+ 'conv2/weights',
+ 'clfs/weights'],
+ include_patterns=['conv1', 'weights'])
+
+ def testIncludeConv1ExcludeBiases(self):
+ self._test_filter_variables(
+ expected_var_names=[
+ 'conv1/weights'],
+ include_patterns=['conv1'],
+ exclude_patterns=['biases'])
+
+ def testRegMatchIncludeBiases(self):
+ self._test_filter_variables(
+ expected_var_names=[
+ 'conv1/biases',
+ 'conv2/biases',
+ 'clfs/biases'],
+ include_patterns=['.*biases'],
+ reg_search=False)
+
+ def testRegMatchIncludeBiasesWithIncompleteRegExpHasNoMatches(self):
+ self._test_filter_variables(
+ expected_var_names=[],
+ include_patterns=['biases'],
+ reg_search=False)
+
if __name__ == '__main__':
tf.test.main()