diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-07-26 15:50:03 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-07-26 17:03:39 -0700 |
commit | d3c0b946631379881083e0ce9f71faaadd651d34 (patch) | |
tree | 079d51851e12ff404cefd84c0110a0c4ad6f6904 | |
parent | 91d404476acc79b680685105b9d0eae0704a87d0 (diff) |
Add ability to warm start embedding weights and crossed column weights from a provided checkpoint.
Change: 128531220
-rw-r--r-- | tensorflow/contrib/layers/python/layers/feature_column.py | 100 | ||||
-rw-r--r-- | tensorflow/contrib/layers/python/layers/feature_column_test.py | 103 |
2 files changed, 187 insertions, 16 deletions
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index f4fedd766d..26d87d8dd3 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -75,6 +75,7 @@ import abc import collections import math +from tensorflow.contrib.framework.python.framework import checkpoint_utils from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.layers.python.layers import embedding_ops from tensorflow.contrib.layers.python.ops import bucketization_op @@ -149,6 +150,7 @@ class _FeatureColumn(object): raise ValueError("Calling an abstract method.") +# TODO(b/30410315): Support warm starting in all feature columns. class _SparseColumn(_FeatureColumn, collections.namedtuple("_SparseColumn", ["column_name", "is_integerized", @@ -568,7 +570,8 @@ def weighted_sparse_column(sparse_id_column, class _EmbeddingColumn(_FeatureColumn, collections.namedtuple( "_EmbeddingColumn", - ["sparse_id_column", "dimension", "combiner", "initializer"])): + ["sparse_id_column", "dimension", "combiner", "initializer", + "ckpt_to_load_from", "tensor_name_in_ckpt"])): """Represents an embedding column. Args: @@ -586,15 +589,31 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple( variable initialization. If not specified, defaults to `tf.truncated_normal_initializer` with mean 0.0 and standard deviation 1/sqrt(sparse_id_column.length). + ckpt_to_load_from: (Optional). String representing checkpoint name/pattern + to restore the column weights. Required if `tensor_name_in_ckpt` is not + None. + tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided + checkpoint from which to restore the column weights. Required if + `ckpt_to_load_from` is not None. + + Raises: + ValueError: if `initializer` is specified and is not callable. Also, + if only one of `ckpt_to_load_from` and `tensor_name_in_ckpt` is specified. """ def __new__(cls, sparse_id_column, dimension, combiner="mean", - initializer=None): + initializer=None, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None): if initializer is not None and not callable(initializer): raise ValueError("initializer must be callable if specified.") + + if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None): + raise ValueError("Must specify both `ckpt_to_load_from` and " + "`tensor_name_in_ckpt` or none of them.") if initializer is None: stddev = 1 / math.sqrt(sparse_id_column.length) # TODO(b/25671353): Better initial value? @@ -602,7 +621,8 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple( stddev=stddev) return super(_EmbeddingColumn, cls).__new__(cls, sparse_id_column, dimension, combiner, - initializer) + initializer, ckpt_to_load_from, + tensor_name_in_ckpt) @property def name(self): @@ -645,7 +665,7 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple( input_tensor, weight_collections=None, trainable=True): - output, _ = _create_embedding_lookup( + output, embedding_weights = _create_embedding_lookup( input_tensor=self.sparse_id_column.id_tensor(input_tensor), weight_tensor=self.sparse_id_column.weight_tensor(input_tensor), vocab_size=self.length, @@ -655,6 +675,13 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple( combiner=self.combiner, trainable=trainable, name=self.name + "_weights") + if self.ckpt_to_load_from is not None: + weights_to_restore = embedding_weights + if len(embedding_weights) == 1: + weights_to_restore = embedding_weights[0] + checkpoint_utils.init_from_checkpoint( + self.ckpt_to_load_from, + {self.tensor_name_in_ckpt: weights_to_restore}) return output # pylint: disable=unused-argument @@ -670,7 +697,9 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple( def embedding_column(sparse_id_column, dimension, combiner="mean", - initializer=None): + initializer=None, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None): """Creates an _EmbeddingColumn. Args: @@ -688,11 +717,18 @@ def embedding_column(sparse_id_column, variable initialization. If not specified, defaults to `tf.truncated_normal_initializer` with mean 0.0 and standard deviation 1/sqrt(sparse_id_column.length). + ckpt_to_load_from: (Optional). String representing checkpoint name/pattern + to restore the column weights. Required if `tensor_name_in_ckpt` is not + None. + tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided + checkpoint from which to restore the column weights. Required if + `ckpt_to_load_from` is not None. Returns: An _EmbeddingColumn. """ - return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer) + return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer, + ckpt_to_load_from, tensor_name_in_ckpt) class _HashedEmbeddingColumn(collections.namedtuple( @@ -1087,7 +1123,8 @@ def bucketized_column(source_column, boundaries): class _CrossedColumn(_FeatureColumn, collections.namedtuple( - "_CrossedColumn", ["columns", "hash_bucket_size", "combiner"])): + "_CrossedColumn", ["columns", "hash_bucket_size", "combiner", + "ckpt_to_load_from", "tensor_name_in_ckpt"])): """Represents a cross transformation also known as composition or union. Instances of this class are immutable. It crosses given `columns`. Crossed @@ -1124,13 +1161,19 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple( * "mean": do l1 normalization * "sqrtn": do l2 normalization For more information: `tf.embedding_lookup_sparse`. + ckpt_to_load_from: (Optional). String representing checkpoint name/pattern + to restore the column weights. Required if `tensor_name_in_ckpt` is not + None. + tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided + checkpoint from which to restore the column weights. Required if + `ckpt_to_load_from` is not None. Raises: TypeError: if all items in columns are not an instance of _SparseColumn, _CrossedColumn, or _BucketizedColumn or hash_bucket_size is not an int. - ValueError: if hash_bucket_size is not > 1 or - len(columns) is not > 1. + ValueError: if hash_bucket_size is not > 1 or len(columns) is not > 1. Also, + if only one of `ckpt_to_load_from` and `tensor_name_in_ckpt` is specified. """ @staticmethod @@ -1138,7 +1181,8 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple( return isinstance(column, (_SparseColumn, _CrossedColumn, _BucketizedColumn)) - def __new__(cls, columns, hash_bucket_size, combiner="sum"): + def __new__(cls, columns, hash_bucket_size, combiner="sum", + ckpt_to_load_from=None, tensor_name_in_ckpt=None): for column in columns: if not _CrossedColumn._is_crossable(column): raise TypeError("columns should be a set of " @@ -1154,10 +1198,16 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple( if hash_bucket_size < 2: raise ValueError("hash_bucket_size should be at least 2.") + if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None): + raise ValueError("Must specify both `ckpt_to_load_from` and " + "`tensor_name_in_ckpt` or none of them.") + sorted_columns = sorted([column for column in columns], key=lambda column: column.name) return super(_CrossedColumn, cls).__new__(cls, tuple(sorted_columns), - hash_bucket_size, combiner) + hash_bucket_size, combiner, + ckpt_to_load_from, + tensor_name_in_ckpt) @property def name(self): @@ -1223,7 +1273,7 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple( num_outputs=1, weight_collections=None, trainable=True): - return _create_embedding_lookup( + output, embedding_weights = _create_embedding_lookup( input_tensor=input_tensor, weight_tensor=None, vocab_size=self.length, @@ -1233,9 +1283,19 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple( combiner=self.combiner, trainable=trainable, name=self.name + "_weights") - - -def crossed_column(columns, hash_bucket_size, combiner="sum"): + if self.ckpt_to_load_from is not None: + weights_to_restore = embedding_weights + if len(embedding_weights) == 1: + weights_to_restore = embedding_weights[0] + checkpoint_utils.init_from_checkpoint( + self.ckpt_to_load_from, + {self.tensor_name_in_ckpt: weights_to_restore}) + return output, embedding_weights + + +def crossed_column(columns, hash_bucket_size, combiner="sum", + ckpt_to_load_from=None, + tensor_name_in_ckpt=None): """Creates a _CrossedColumn. Args: @@ -1243,6 +1303,12 @@ def crossed_column(columns, hash_bucket_size, combiner="sum"): _SparseColumn, _CrossedColumn, or _BucketizedColumn. hash_bucket_size: An int that is > 1. The number of buckets. combiner: A combiner string, supports sum, mean, sqrtn. + ckpt_to_load_from: (Optional). String representing checkpoint name/pattern + to restore the column weights. Required if `tensor_name_in_ckpt` is not + None. + tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided + checkpoint from which to restore the column weights. Required if + `ckpt_to_load_from` is not None. Returns: A _CrossedColumn. @@ -1254,7 +1320,9 @@ def crossed_column(columns, hash_bucket_size, combiner="sum"): ValueError: if hash_bucket_size is not > 1 or len(columns) is not > 1. """ - return _CrossedColumn(columns, hash_bucket_size, combiner=combiner) + return _CrossedColumn(columns, hash_bucket_size, combiner=combiner, + ckpt_to_load_from=ckpt_to_load_from, + tensor_name_in_ckpt=tensor_name_in_ckpt) class DataFrameColumn(_FeatureColumn, diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py index 187fadfce6..86d522dedf 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + import tensorflow as tf @@ -323,6 +325,107 @@ class FeatureColumnTest(tf.test.TestCase): self.assertEqual(tf.float32, placeholder.dtype) self.assertEqual([None, 1], placeholder.get_shape().as_list()) + def testInitEmbeddingColumnWeightsFromCkpt(self): + sparse_col = tf.contrib.layers.sparse_column_with_hash_bucket( + column_name="object_in_image", + hash_bucket_size=4) + # Create _EmbeddingColumn which randomly initializes embedding of size + # [4, 16]. + embedding_col = tf.contrib.layers.embedding_column(sparse_col, dimension=16) + + # Creating a SparseTensor which has all the ids possible for the given + # vocab. + input_tensor = tf.SparseTensor(indices=[[0, 0], [1, 1], [2, 2], [3, 3]], + values=[0, 1, 2, 3], + shape=[4, 4]) + + # Invoking 'embedding_column.to_dnn_input_layer' will create the embedding + # variable. Creating under scope 'run_1' so as to prevent name conflicts + # when creating embedding variable for 'embedding_column_pretrained'. + with tf.variable_scope("run_1"): + # This will return a [4, 16] tensor which is same as embedding variable. + embeddings = embedding_col.to_dnn_input_layer(input_tensor) + + save = tf.train.Saver() + checkpoint_path = os.path.join(self.get_temp_dir(), "model.ckpt") + + with self.test_session() as sess: + sess.run(tf.initialize_all_variables()) + saved_embedding = embeddings.eval() + save.save(sess, checkpoint_path) + + embedding_col_initialized = tf.contrib.layers.embedding_column( + sparse_id_column=sparse_col, + dimension=16, + ckpt_to_load_from=checkpoint_path, + tensor_name_in_ckpt="run_1/object_in_image_embedding_weights") + + with tf.variable_scope("run_2"): + # This will initialize the embedding from provided checkpoint and return a + # [4, 16] tensor which is same as embedding variable. Since we didn't + # modify embeddings, this should be same as 'saved_embedding'. + pretrained_embeddings = embedding_col_initialized.to_dnn_input_layer( + input_tensor) + + with self.test_session() as sess: + sess.run(tf.initialize_all_variables()) + loaded_embedding = pretrained_embeddings.eval() + + self.assertAllClose(saved_embedding, loaded_embedding) + + def testInitCrossedColumnWeightsFromCkpt(self): + sparse_col_1 = tf.contrib.layers.sparse_column_with_hash_bucket( + column_name="col_1", hash_bucket_size=4) + sparse_col_2 = tf.contrib.layers.sparse_column_with_hash_bucket( + column_name="col_2", hash_bucket_size=4) + + crossed_col = tf.contrib.layers.crossed_column( + columns=[sparse_col_1, sparse_col_2], + hash_bucket_size=4) + + input_tensor = tf.SparseTensor(indices=[[0, 0], [1, 1], [2, 2], [3, 3]], + values=[0, 1, 2, 3], + shape=[4, 4]) + + # Invoking 'crossed_col.to_weighted_sum' will create the crossed column + # weights variable. + with tf.variable_scope("run_1"): + # Returns looked up column weights which is same as crossed column weights + # as well as actual references to weights variables. + col_weights, weights = crossed_col.to_weighted_sum(input_tensor) + # Update the weights since default initializer initializes all weights to + # 0.0. + for weight in weights: + assign_op = tf.assign(weight, weight + 0.5) + + save = tf.train.Saver() + checkpoint_path = os.path.join(self.get_temp_dir(), "model.ckpt") + + with self.test_session() as sess: + sess.run(tf.initialize_all_variables()) + sess.run(assign_op) + saved_col_weights = col_weights.eval() + save.save(sess, checkpoint_path) + + crossed_col_initialized = tf.contrib.layers.crossed_column( + columns=[sparse_col_1, sparse_col_2], + hash_bucket_size=4, + ckpt_to_load_from=checkpoint_path, + tensor_name_in_ckpt="run_1/col_1_X_col_2_weights") + + with tf.variable_scope("run_2"): + # This will initialize the crossed column weights from provided checkpoint + # and return a [4, 1] tensor which is same as weights variable. Since we + # won't modify weights, this should be same as 'saved_col_weights'. + col_weights_from_ckpt, _ = crossed_col_initialized.to_weighted_sum( + input_tensor) + + with self.test_session() as sess: + sess.run(tf.initialize_all_variables()) + loaded_col_weights = col_weights_from_ckpt.eval() + + self.assertAllClose(saved_col_weights, loaded_col_weights) + if __name__ == "__main__": tf.test.main() |