diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-09 18:17:28 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-09 18:21:05 -0800 |
commit | 5fce499739af0cea5f618ec43eb6b41d45cd72a8 (patch) | |
tree | c600d2637d00fc61bd8a441cb941fe7414b70824 | |
parent | a329c8824ba53be92644f11e46fde7a3b9d1c42f (diff) |
Raise an error if the variable names in WarmStartSettings aren't actually used.
PiperOrigin-RevId: 181404919
-rw-r--r-- | tensorflow/python/estimator/warm_starting_util.py | 36 | ||||
-rw-r--r-- | tensorflow/python/estimator/warm_starting_util_test.py | 21 |
2 files changed, 57 insertions, 0 deletions
```diff
diff --git a/tensorflow/python/estimator/warm_starting_util.py b/tensorflow/python/estimator/warm_starting_util.py
index 37ac8515cb..fa65f55070 100644
--- a/tensorflow/python/estimator/warm_starting_util.py
+++ b/tensorflow/python/estimator/warm_starting_util.py
@@ -377,6 +377,12 @@ def _warmstart(warmstart_settings):

   Args:
     warmstart_settings: An object of `_WarmStartSettings`.
+
+  Raises:
+    ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
+      configuration for variable names that are not used. This is to ensure
+      a stronger check for variable configuration than relying on users to
+      examine the logs.
   """
   # We have to deal with partitioned variables, since get_collection flattens
   # out the list.
@@ -390,10 +396,22 @@ def _warmstart(warmstart_settings):
     else:
       var_name = _infer_var_name(v)
     grouped_variables.setdefault(var_name, []).append(v)
+
+  # Keep track of which var_names in var_name_to_prev_var_name and
+  # var_name_to_vocab_info have been used. Err on the safer side by throwing an
+  # exception if any are unused by the end of the loop. It is easy to misname
+  # a variable during this configuration, in which case without this check, we
+  # would fail to warmstart silently.
+  prev_var_name_used = set()
+  vocab_info_used = set()
+
   for var_name, variable in six.iteritems(grouped_variables):
     prev_var_name = warmstart_settings.var_name_to_prev_var_name.get(var_name)
+    if prev_var_name:
+      prev_var_name_used.add(var_name)
     vocab_info = warmstart_settings.var_name_to_vocab_info.get(var_name)
     if vocab_info:
+      vocab_info_used.add(var_name)
       logging.info(
           "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
           " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
@@ -430,3 +448,21 @@ def _warmstart(warmstart_settings):
           variable = variable[0]
         _warmstart_var(variable, warmstart_settings.ckpt_to_initialize_from,
                        prev_var_name)
+
+  prev_var_name_not_used = set(
+      warmstart_settings.var_name_to_prev_var_name.keys()) - prev_var_name_used
+  vocab_info_not_used = set(
+      warmstart_settings.var_name_to_vocab_info.keys()) - vocab_info_used
+
+  if prev_var_name_not_used:
+    raise ValueError(
+        "You provided the following variables in "
+        "warmstart_settings.var_name_to_prev_var_name that were not used: {0}. "
+        " Perhaps you misspelled them? Here is the list of viable variable "
+        "names: {1}".format(prev_var_name_not_used, grouped_variables.keys()))
+  if vocab_info_not_used:
+    raise ValueError(
+        "You provided the following variables in "
+        "warmstart_settings.var_name_to_vocab_info that were not used: {0}. "
+        " Perhaps you misspelled them? Here is the list of viable variable "
+        "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))
diff --git a/tensorflow/python/estimator/warm_starting_util_test.py b/tensorflow/python/estimator/warm_starting_util_test.py
index cc0c4efc75..23445a1c37 100644
--- a/tensorflow/python/estimator/warm_starting_util_test.py
+++ b/tensorflow/python/estimator/warm_starting_util_test.py
@@ -992,6 +992,27 @@ class WarmStartingUtilTest(test.TestCase):
       self.assertRaises(TypeError, ws_util._warmstart,
                         {"StringType": x}, ws_util._WarmStartSettings("/tmp"))

+    # Unused variable names raises ValueError.
+    with ops.Graph().as_default():
+      with self.test_session() as sess:
+        x = variable_scope.get_variable(
+            "x",
+            shape=[4, 1],
+            initializer=ones(),
+            partitioner=lambda shape, dtype: [2, 1])
+        self._write_checkpoint(sess)
+
+      self.assertRaises(ValueError, ws_util._warmstart,
+                        ws_util._WarmStartSettings(
+                            self.get_temp_dir(),
+                            var_name_to_vocab_info={
+                                "y": ws_util._VocabInfo("", 1, 0, "")
+                            }))
+      self.assertRaises(ValueError, ws_util._warmstart,
+                        ws_util._WarmStartSettings(
+                            self.get_temp_dir(),
+                            var_name_to_prev_var_name={"y": "y2"}))
+

 if __name__ == "__main__":
   test.main()
```