aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-09 18:17:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-09 18:21:05 -0800
commit5fce499739af0cea5f618ec43eb6b41d45cd72a8 (patch)
treec600d2637d00fc61bd8a441cb941fe7414b70824
parenta329c8824ba53be92644f11e46fde7a3b9d1c42f (diff)
Raise an error if the variable names in WarmStartSettings aren't actually used.
PiperOrigin-RevId: 181404919
-rw-r--r--tensorflow/python/estimator/warm_starting_util.py36
-rw-r--r--tensorflow/python/estimator/warm_starting_util_test.py21
2 files changed, 57 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/warm_starting_util.py b/tensorflow/python/estimator/warm_starting_util.py
index 37ac8515cb..fa65f55070 100644
--- a/tensorflow/python/estimator/warm_starting_util.py
+++ b/tensorflow/python/estimator/warm_starting_util.py
@@ -377,6 +377,12 @@ def _warmstart(warmstart_settings):
Args:
warmstart_settings: An object of `_WarmStartSettings`.
+
+ Raises:
+ ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
+ configuration for variable names that are not used. This is to ensure
+ a stronger check for variable configuration than relying on users to
+ examine the logs.
"""
# We have to deal with partitioned variables, since get_collection flattens
# out the list.
@@ -390,10 +396,22 @@ def _warmstart(warmstart_settings):
else:
var_name = _infer_var_name(v)
grouped_variables.setdefault(var_name, []).append(v)
+
+ # Keep track of which var_names in var_name_to_prev_var_name and
+ # var_name_to_vocab_info have been used. Err on the safer side by throwing an
+ # exception if any are unused by the end of the loop. It is easy to misname
+ # a variable during this configuration, in which case without this check, we
+ # would fail to warmstart silently.
+ prev_var_name_used = set()
+ vocab_info_used = set()
+
for var_name, variable in six.iteritems(grouped_variables):
prev_var_name = warmstart_settings.var_name_to_prev_var_name.get(var_name)
+ if prev_var_name:
+ prev_var_name_used.add(var_name)
vocab_info = warmstart_settings.var_name_to_vocab_info.get(var_name)
if vocab_info:
+ vocab_info_used.add(var_name)
logging.info(
"Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
" prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
@@ -430,3 +448,21 @@ def _warmstart(warmstart_settings):
variable = variable[0]
_warmstart_var(variable, warmstart_settings.ckpt_to_initialize_from,
prev_var_name)
+
+ prev_var_name_not_used = set(
+ warmstart_settings.var_name_to_prev_var_name.keys()) - prev_var_name_used
+ vocab_info_not_used = set(
+ warmstart_settings.var_name_to_vocab_info.keys()) - vocab_info_used
+
+ if prev_var_name_not_used:
+ raise ValueError(
+ "You provided the following variables in "
+ "warmstart_settings.var_name_to_prev_var_name that were not used: {0}. "
+ " Perhaps you misspelled them? Here is the list of viable variable "
+ "names: {1}".format(prev_var_name_not_used, grouped_variables.keys()))
+ if vocab_info_not_used:
+ raise ValueError(
+ "You provided the following variables in "
+ "warmstart_settings.var_name_to_vocab_info that were not used: {0}. "
+ " Perhaps you misspelled them? Here is the list of viable variable "
+ "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))
diff --git a/tensorflow/python/estimator/warm_starting_util_test.py b/tensorflow/python/estimator/warm_starting_util_test.py
index cc0c4efc75..23445a1c37 100644
--- a/tensorflow/python/estimator/warm_starting_util_test.py
+++ b/tensorflow/python/estimator/warm_starting_util_test.py
@@ -992,6 +992,27 @@ class WarmStartingUtilTest(test.TestCase):
self.assertRaises(TypeError, ws_util._warmstart,
{"StringType": x}, ws_util._WarmStartSettings("/tmp"))
+ # Unused variable names raises ValueError.
+ with ops.Graph().as_default():
+ with self.test_session() as sess:
+ x = variable_scope.get_variable(
+ "x",
+ shape=[4, 1],
+ initializer=ones(),
+ partitioner=lambda shape, dtype: [2, 1])
+ self._write_checkpoint(sess)
+
+ self.assertRaises(ValueError, ws_util._warmstart,
+ ws_util._WarmStartSettings(
+ self.get_temp_dir(),
+ var_name_to_vocab_info={
+ "y": ws_util._VocabInfo("", 1, 0, "")
+ }))
+ self.assertRaises(ValueError, ws_util._warmstart,
+ ws_util._WarmStartSettings(
+ self.get_temp_dir(),
+ var_name_to_prev_var_name={"y": "y2"}))
+
if __name__ == "__main__":
test.main()