aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-08 16:08:44 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-08 16:17:00 -0800
commit5976ab9b91ee6e236335ba4a322f5a514b29da7f (patch)
tree3599d95bac3d5ddbd895c067320fd2a6e0863a16
parentcab02277fbabde3140fc0d9a198ff30fe40e8f36 (diff)
Enable _warmstart() to work with un-partitioned variables.
PiperOrigin-RevId: 181233898
-rw-r--r--tensorflow/python/estimator/warm_starting_util.py5
-rw-r--r--tensorflow/python/estimator/warm_starting_util_test.py61
2 files changed, 66 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/warm_starting_util.py b/tensorflow/python/estimator/warm_starting_util.py
index 5830251f6a..476776daa8 100644
--- a/tensorflow/python/estimator/warm_starting_util.py
+++ b/tensorflow/python/estimator/warm_starting_util.py
@@ -420,5 +420,10 @@ def _warmstart(warmstart_settings):
if warmstart_settings.vars_to_warmstart:
logging.info("Warm-starting variable: {}; prev_var_name: {}".format(
var_name, prev_var_name or "Unchanged"))
+ # Because we use a default empty list in grouped_variables, single
+ # unpartitioned variables will be lists here, which we rectify in order
+ # for init_from_checkpoint logic to work correctly.
+ if len(variable) == 1:
+ variable = variable[0]
_warmstart_var(variable, warmstart_settings.ckpt_to_initialize_from,
prev_var_name)
diff --git a/tensorflow/python/estimator/warm_starting_util_test.py b/tensorflow/python/estimator/warm_starting_util_test.py
index 18a70c530c..cf502dd60d 100644
--- a/tensorflow/python/estimator/warm_starting_util_test.py
+++ b/tensorflow/python/estimator/warm_starting_util_test.py
@@ -659,6 +659,67 @@ class WarmStartingUtilTest(test.TestCase):
]
}, sess)
+ def testWarmStartMoreSettingsNoPartitioning(self):
+ # Create old and new vocabs for sparse column "sc_vocab".
+ prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
+ "old_vocab")
+ new_vocab_path = self._write_vocab(
+ ["orange", "guava", "banana", "apple", "raspberry",
+ "blueberry"], "new_vocab")
+ # Create feature columns.
+ sc_hash = fc.categorical_column_with_hash_bucket(
+ "sc_hash", hash_bucket_size=15)
+ sc_keys = fc.categorical_column_with_vocabulary_list(
+ "sc_keys", vocabulary_list=["a", "b", "c", "e"])
+ sc_vocab = fc.categorical_column_with_vocabulary_file(
+ "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
+ all_linear_cols = [sc_hash, sc_keys, sc_vocab]
+
+ # Save checkpoint from which to warm-start.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ variable_scope.get_variable(
+ "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
+ sc_keys_weights = variable_scope.get_variable(
+ "some_other_name", shape=[4, 1], initializer=rand())
+ variable_scope.get_variable(
+ "linear_model/sc_vocab/weights",
+ initializer=[[0.5], [1.], [2.], [3.]])
+ self._write_checkpoint(sess)
+ prev_keys_val = sess.run(sc_keys_weights)
+
+ # New graph, new session with warmstarting.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ cols_to_vars = self._create_linear_model(all_linear_cols,
+ partitioner=None)
+ vocab_info = ws_util._VocabInfo(
+ new_vocab=sc_vocab.vocabulary_file,
+ new_vocab_size=sc_vocab.vocabulary_size,
+ num_oov_buckets=sc_vocab.num_oov_buckets,
+ old_vocab=prev_vocab_path
+ )
+ ws_settings = ws_util._WarmStartSettings(
+ self.get_temp_dir(),
+ vars_to_warmstart=".*(sc_keys|sc_vocab).*",
+ var_name_to_vocab_info={
+ ws_util._infer_var_name(cols_to_vars[sc_vocab]): vocab_info
+ },
+ var_name_to_prev_var_name={
+ ws_util._infer_var_name(cols_to_vars[sc_keys]):
+ "some_other_name"
+ })
+ ws_util._warmstart(ws_settings)
+ sess.run(variables.global_variables_initializer())
+ # Verify weights were correctly warmstarted. Var corresponding to
+ # sc_hash should not be warm-started. Var corresponding to sc_vocab
+ # should be correctly warmstarted after vocab remapping.
+ self._assert_cols_to_vars(cols_to_vars, {
+ sc_keys: [prev_keys_val],
+ sc_hash: [np.zeros([15, 1])],
+ sc_vocab: [np.array([[3.], [2.], [1.], [0.5], [0.], [0.]])]
+ }, sess)
+
def testWarmStartVarsToWarmstartIsNone(self):
# Create old and new vocabs for sparse column "sc_vocab".
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],