diff options
author | 2016-12-02 06:59:21 -0800 | |
---|---|---|
committer | 2016-12-02 07:07:49 -0800 | |
commit | c499f4211d1a963d1d123f91632f48a056de7eaa (patch) | |
tree | d598b6426a92ce292358cf64e9ca3b458d622bab | |
parent | 41b616baaef5271917b016a0493a9f87217607c4 (diff) |
Added convenience function to filter a list of variables based on regular expressions.
Change: 140845713
-rw-r--r-- | tensorflow/contrib/framework/python/ops/variables.py | 63 | ||||
-rw-r--r-- | tensorflow/contrib/framework/python/ops/variables_test.py | 103 |
2 files changed, 163 insertions, 3 deletions
diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 2db91cd889..c6425e8d24 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import re + from tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope from tensorflow.contrib.framework.python.ops import gen_variable_ops from tensorflow.contrib.util import loader @@ -46,6 +48,7 @@ __all__ = ['add_model_variable', 'assign_from_values', 'assign_from_values_fn', 'create_global_step', + 'filter_variables', 'get_global_step', 'get_or_create_global_step', 'get_local_variables', @@ -624,3 +627,63 @@ class VariableDeviceChooser(object): device_spec.job = self._job_name device_spec.task = task_id return device_spec.to_string() + + +def filter_variables(var_list, include_patterns=None, exclude_patterns=None, + reg_search=True): + """Filter a list of variables using regular expressions. + + First includes variables according to the list of include_patterns. + Afterwards, eliminates variables according to the list of exclude_patterns. + + For example, one can obtain a list of variables with the weights of all + convolutional layers (depending on the network definition) by: + + ```python + variables = tf.contrib.framework.get_model_variables() + conv_weight_variables = tf.contrib.framework.filter_variables( + variables, + include_patterns=['Conv'], + exclude_patterns=['biases', 'Logits']) + ``` + + Args: + var_list: list of variables. + include_patterns: list of regular expressions to include. Defaults to None, + which means all variables are selected according to the include rules. + A variable is included if it matches any of the include_patterns. + exclude_patterns: list of regular expressions to exclude. Defaults to None, + which means all variables are selected according to the exclude rules. + A variable is excluded if it matches any of the exclude_patterns. + reg_search: boolean. If True (default), performs re.search to find matches + (i.e. pattern can match any substring of the variable name). If False, + performs re.match (i.e. regexp should match from the beginning of the + variable name). + + Returns: + filtered list of variables. + """ + if reg_search: + reg_exp_func = re.search + else: + reg_exp_func = re.match + + # First include variables. + if include_patterns is None: + included_variables = list(var_list) + else: + included_variables = [] + for var in var_list: + if any(reg_exp_func(ptrn, var.name) for ptrn in include_patterns): + included_variables.append(var) + + # Afterwards, exclude variables. + if exclude_patterns is None: + filtered_variables = included_variables + else: + filtered_variables = [] + for var in included_variables: + if not any(reg_exp_func(ptrn, var.name) for ptrn in exclude_patterns): + filtered_variables.append(var) + + return filtered_variables diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index 191291cbbf..d0264678da 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -1094,17 +1094,18 @@ class AssignFromCheckpointFnTest(tf.test.TestCase): self.assertEqual(init_value0, var0.eval()) self.assertEqual(init_value1, var1.eval()) + class ZeroInitializerOpTest(tf.test.TestCase): def _testZeroInitializer(self, shape, initializer, use_init): var = tf.Variable(initializer) var_zero = tf.contrib.framework.zero_initializer(var) with self.test_session() as sess: - with self.assertRaisesOpError("Attempting to use uninitialized value"): + with self.assertRaisesOpError('Attempting to use uninitialized value'): var.eval() if use_init: sess.run(var.initializer) - with self.assertRaisesOpError("input is already initialized"): + with self.assertRaisesOpError('input is already initialized'): var_zero.eval() self.assertAllClose(np.ones(shape), var.eval()) else: @@ -1115,7 +1116,103 @@ class ZeroInitializerOpTest(tf.test.TestCase): for dtype in (tf.int32, tf.int64, tf.float32, tf.float64): for use_init in (False, True): self._testZeroInitializer( - [10, 20], tf.ones([10, 20], dtype = dtype), use_init) + [10, 20], tf.ones([10, 20], dtype=dtype), use_init) + + +class FilterVariablesTest(tf.test.TestCase): + + def setUp(self): + g = tf.Graph() + with g.as_default(): + var_list = [] + var_list.append(tf.Variable(0, name='conv1/weights')) + var_list.append(tf.Variable(0, name='conv1/biases')) + var_list.append(tf.Variable(0, name='conv2/weights')) + var_list.append(tf.Variable(0, name='conv2/biases')) + var_list.append(tf.Variable(0, name='clfs/weights')) + var_list.append(tf.Variable(0, name='clfs/biases')) + self._var_list = var_list + + def _test_filter_variables(self, expected_var_names, include_patterns=None, + exclude_patterns=None, reg_search=True): + filtered_var_list = tf.contrib.framework.filter_variables( + self._var_list, + include_patterns=include_patterns, + exclude_patterns=exclude_patterns, + reg_search=reg_search) + + filtered_var_names = [var.op.name for var in filtered_var_list] + + for name in filtered_var_names: + self.assertIn(name, expected_var_names) + for name in expected_var_names: + self.assertIn(name, filtered_var_names) + self.assertEqual(len(filtered_var_names), len(expected_var_names)) + + def testNoFiltering(self): + self._test_filter_variables( + expected_var_names=[ + 'conv1/weights', + 'conv1/biases', + 'conv2/weights', + 'conv2/biases', + 'clfs/weights', + 'clfs/biases']) + + def testIncludeBiases(self): + self._test_filter_variables( + expected_var_names=[ + 'conv1/biases', + 'conv2/biases', + 'clfs/biases'], + include_patterns=['biases']) + + def testExcludeWeights(self): + self._test_filter_variables( + expected_var_names=[ + 'conv1/biases', + 'conv2/biases', + 'clfs/biases'], + exclude_patterns=['weights']) + + def testExcludeWeightsAndConv1(self): + self._test_filter_variables( + expected_var_names=[ + 'conv2/biases', + 'clfs/biases'], + exclude_patterns=['weights', 'conv1']) + + def testTwoIncludePatternsEnsureNoVariablesTwiceInFilteredList(self): + self._test_filter_variables( + expected_var_names=[ + 'conv1/weights', + 'conv1/biases', + 'conv2/weights', + 'clfs/weights'], + include_patterns=['conv1', 'weights']) + + def testIncludeConv1ExcludeBiases(self): + self._test_filter_variables( + expected_var_names=[ + 'conv1/weights'], + include_patterns=['conv1'], + exclude_patterns=['biases']) + + def testRegMatchIncludeBiases(self): + self._test_filter_variables( + expected_var_names=[ + 'conv1/biases', + 'conv2/biases', + 'clfs/biases'], + include_patterns=['.*biases'], + reg_search=False) + + def testRegMatchIncludeBiasesWithIncompleteRegExpHasNoMatches(self): + self._test_filter_variables( + expected_var_names=[], + include_patterns=['biases'], + reg_search=False) + if __name__ == '__main__': tf.test.main() |