diff options
author | 2018-01-08 16:08:44 -0800 | |
---|---|---|
committer | 2018-01-08 16:17:00 -0800 | |
commit | 5976ab9b91ee6e236335ba4a322f5a514b29da7f (patch) | |
tree | 3599d95bac3d5ddbd895c067320fd2a6e0863a16 | |
parent | cab02277fbabde3140fc0d9a198ff30fe40e8f36 (diff) |
Enable _warmstart() to work with un-partitioned variables.
PiperOrigin-RevId: 181233898
-rw-r--r-- | tensorflow/python/estimator/warm_starting_util.py | 5 | ||||
-rw-r--r-- | tensorflow/python/estimator/warm_starting_util_test.py | 61 |
2 files changed, 66 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/warm_starting_util.py b/tensorflow/python/estimator/warm_starting_util.py index 5830251f6a..476776daa8 100644 --- a/tensorflow/python/estimator/warm_starting_util.py +++ b/tensorflow/python/estimator/warm_starting_util.py @@ -420,5 +420,10 @@ def _warmstart(warmstart_settings): if warmstart_settings.vars_to_warmstart: logging.info("Warm-starting variable: {}; prev_var_name: {}".format( var_name, prev_var_name or "Unchanged")) + # Because we use a default empty list in grouped_variables, single + # unpartitioned variables will be lists here, which we rectify in order + # for init_from_checkpoint logic to work correctly. + if len(variable) == 1: + variable = variable[0] _warmstart_var(variable, warmstart_settings.ckpt_to_initialize_from, prev_var_name) diff --git a/tensorflow/python/estimator/warm_starting_util_test.py b/tensorflow/python/estimator/warm_starting_util_test.py index 18a70c530c..cf502dd60d 100644 --- a/tensorflow/python/estimator/warm_starting_util_test.py +++ b/tensorflow/python/estimator/warm_starting_util_test.py @@ -659,6 +659,67 @@ class WarmStartingUtilTest(test.TestCase): ] }, sess) + def testWarmStartMoreSettingsNoPartitioning(self): + # Create old and new vocabs for sparse column "sc_vocab". + prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], + "old_vocab") + new_vocab_path = self._write_vocab( + ["orange", "guava", "banana", "apple", "raspberry", + "blueberry"], "new_vocab") + # Create feature columns. + sc_hash = fc.categorical_column_with_hash_bucket( + "sc_hash", hash_bucket_size=15) + sc_keys = fc.categorical_column_with_vocabulary_list( + "sc_keys", vocabulary_list=["a", "b", "c", "e"]) + sc_vocab = fc.categorical_column_with_vocabulary_file( + "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6) + all_linear_cols = [sc_hash, sc_keys, sc_vocab] + + # Save checkpoint from which to warm-start. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + variable_scope.get_variable( + "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms()) + sc_keys_weights = variable_scope.get_variable( + "some_other_name", shape=[4, 1], initializer=rand()) + variable_scope.get_variable( + "linear_model/sc_vocab/weights", + initializer=[[0.5], [1.], [2.], [3.]]) + self._write_checkpoint(sess) + prev_keys_val = sess.run(sc_keys_weights) + + # New graph, new session with warmstarting. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + cols_to_vars = self._create_linear_model(all_linear_cols, + partitioner=None) + vocab_info = ws_util._VocabInfo( + new_vocab=sc_vocab.vocabulary_file, + new_vocab_size=sc_vocab.vocabulary_size, + num_oov_buckets=sc_vocab.num_oov_buckets, + old_vocab=prev_vocab_path + ) + ws_settings = ws_util._WarmStartSettings( + self.get_temp_dir(), + vars_to_warmstart=".*(sc_keys|sc_vocab).*", + var_name_to_vocab_info={ + ws_util._infer_var_name(cols_to_vars[sc_vocab]): vocab_info + }, + var_name_to_prev_var_name={ + ws_util._infer_var_name(cols_to_vars[sc_keys]): + "some_other_name" + }) + ws_util._warmstart(ws_settings) + sess.run(variables.global_variables_initializer()) + # Verify weights were correctly warmstarted. Var corresponding to + # sc_hash should not be warm-started. Var corresponding to sc_vocab + # should be correctly warmstarted after vocab remapping. + self._assert_cols_to_vars(cols_to_vars, { + sc_keys: [prev_keys_val], + sc_hash: [np.zeros([15, 1])], + sc_vocab: [np.array([[3.], [2.], [1.], [0.5], [0.], [0.]])] + }, sess) + def testWarmStartVarsToWarmstartIsNone(self): # Create old and new vocabs for sparse column "sc_vocab". prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], |