author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-07-26 15:50:03 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-07-26 17:03:39 -0700
commit    d3c0b946631379881083e0ce9f71faaadd651d34 (patch)
tree      079d51851e12ff404cefd84c0110a0c4ad6f6904
parent    91d404476acc79b680685105b9d0eae0704a87d0 (diff)
Add the ability to warm-start embedding weights and crossed column weights from a provided checkpoint.
Change: 128531220
 tensorflow/contrib/layers/python/layers/feature_column.py      | 100
 tensorflow/contrib/layers/python/layers/feature_column_test.py | 103
 2 files changed, 187 insertions(+), 16 deletions(-)
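For orientation, here is a minimal usage sketch of the API this patch adds. This is a hedged example, not code from the commit: the column name, checkpoint path, and tensor name are placeholders, and `tensor_name_in_ckpt` must match the embedding variable's name as saved in the checkpoint.

import tensorflow as tf

# Any sparse id column works; this mirrors the column used in the tests.
sparse_col = tf.contrib.layers.sparse_column_with_hash_bucket(
    column_name="object_in_image", hash_bucket_size=4)

# Warm-start the embedding weights from a previously saved checkpoint.
# Both checkpoint arguments must be given together or not at all.
embedding_col = tf.contrib.layers.embedding_column(
    sparse_col,
    dimension=16,
    ckpt_to_load_from="/tmp/model.ckpt",
    tensor_name_in_ckpt="run_1/object_in_image_embedding_weights")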
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index f4fedd766d..26d87d8dd3 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -75,6 +75,7 @@ import abc
import collections
import math
+from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import embedding_ops
from tensorflow.contrib.layers.python.ops import bucketization_op
@@ -149,6 +150,7 @@ class _FeatureColumn(object):
raise ValueError("Calling an abstract method.")
+# TODO(b/30410315): Support warm starting in all feature columns.
class _SparseColumn(_FeatureColumn,
collections.namedtuple("_SparseColumn",
["column_name", "is_integerized",
@@ -568,7 +570,8 @@ def weighted_sparse_column(sparse_id_column,
class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
"_EmbeddingColumn",
- ["sparse_id_column", "dimension", "combiner", "initializer"])):
+ ["sparse_id_column", "dimension", "combiner", "initializer",
+ "ckpt_to_load_from", "tensor_name_in_ckpt"])):
"""Represents an embedding column.
Args:
@@ -586,15 +589,31 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
variable initialization. If not specified, defaults to
`tf.truncated_normal_initializer` with mean 0.0 and standard deviation
1/sqrt(sparse_id_column.length).
+ ckpt_to_load_from: (Optional). String representing checkpoint name/pattern
+ from which to restore the column weights. Required if `tensor_name_in_ckpt`
+ is not None.
+ tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided
+ checkpoint from which to restore the column weights. Required if
+ `ckpt_to_load_from` is not None.
+
+ Raises:
+ ValueError: if `initializer` is specified and is not callable. Also,
+ if only one of `ckpt_to_load_from` and `tensor_name_in_ckpt` is specified.
"""
def __new__(cls,
sparse_id_column,
dimension,
combiner="mean",
- initializer=None):
+ initializer=None,
+ ckpt_to_load_from=None,
+ tensor_name_in_ckpt=None):
if initializer is not None and not callable(initializer):
raise ValueError("initializer must be callable if specified.")
+
+ if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
+ raise ValueError("Must specify both `ckpt_to_load_from` and "
+ "`tensor_name_in_ckpt` or none of them.")
if initializer is None:
stddev = 1 / math.sqrt(sparse_id_column.length)
# TODO(b/25671353): Better initial value?
@@ -602,7 +621,8 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
stddev=stddev)
return super(_EmbeddingColumn, cls).__new__(cls, sparse_id_column,
dimension, combiner,
- initializer)
+ initializer, ckpt_to_load_from,
+ tensor_name_in_ckpt)
@property
def name(self):
@@ -645,7 +665,7 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
input_tensor,
weight_collections=None,
trainable=True):
- output, _ = _create_embedding_lookup(
+ output, embedding_weights = _create_embedding_lookup(
input_tensor=self.sparse_id_column.id_tensor(input_tensor),
weight_tensor=self.sparse_id_column.weight_tensor(input_tensor),
vocab_size=self.length,
@@ -655,6 +675,13 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
combiner=self.combiner,
trainable=trainable,
name=self.name + "_weights")
+ if self.ckpt_to_load_from is not None:
+ weights_to_restore = embedding_weights
+ if len(embedding_weights) == 1:
+ weights_to_restore = embedding_weights[0]
+ checkpoint_utils.init_from_checkpoint(
+ self.ckpt_to_load_from,
+ {self.tensor_name_in_ckpt: weights_to_restore})
return output
# pylint: disable=unused-argument
@@ -670,7 +697,9 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
def embedding_column(sparse_id_column,
dimension,
combiner="mean",
- initializer=None):
+ initializer=None,
+ ckpt_to_load_from=None,
+ tensor_name_in_ckpt=None):
"""Creates an _EmbeddingColumn.
Args:
@@ -688,11 +717,18 @@ def embedding_column(sparse_id_column,
variable initialization. If not specified, defaults to
`tf.truncated_normal_initializer` with mean 0.0 and standard deviation
1/sqrt(sparse_id_column.length).
+ ckpt_to_load_from: (Optional). String representing checkpoint name/pattern
+ from which to restore the column weights. Required if `tensor_name_in_ckpt`
+ is not None.
+ tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided
+ checkpoint from which to restore the column weights. Required if
+ `ckpt_to_load_from` is not None.
Returns:
An _EmbeddingColumn.
"""
- return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer)
+ return _EmbeddingColumn(sparse_id_column, dimension, combiner, initializer,
+ ckpt_to_load_from, tensor_name_in_ckpt)
class _HashedEmbeddingColumn(collections.namedtuple(
@@ -1087,7 +1123,8 @@ def bucketized_column(source_column, boundaries):
class _CrossedColumn(_FeatureColumn, collections.namedtuple(
- "_CrossedColumn", ["columns", "hash_bucket_size", "combiner"])):
+ "_CrossedColumn", ["columns", "hash_bucket_size", "combiner",
+ "ckpt_to_load_from", "tensor_name_in_ckpt"])):
"""Represents a cross transformation also known as composition or union.
Instances of this class are immutable. It crosses given `columns`. Crossed
@@ -1124,13 +1161,19 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
* "mean": do l1 normalization
* "sqrtn": do l2 normalization
For more information: `tf.embedding_lookup_sparse`.
+ ckpt_to_load_from: (Optional). String representing checkpoint name/pattern
+ from which to restore the column weights. Required if `tensor_name_in_ckpt`
+ is not None.
+ tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided
+ checkpoint from which to restore the column weights. Required if
+ `ckpt_to_load_from` is not None.
Raises:
TypeError: if all items in columns are not an instance of _SparseColumn,
_CrossedColumn, or _BucketizedColumn or
hash_bucket_size is not an int.
- ValueError: if hash_bucket_size is not > 1 or
- len(columns) is not > 1.
+ ValueError: if hash_bucket_size is not > 1 or len(columns) is not > 1. Also,
+ if only one of `ckpt_to_load_from` and `tensor_name_in_ckpt` is specified.
"""
@staticmethod
@@ -1138,7 +1181,8 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
return isinstance(column,
(_SparseColumn, _CrossedColumn, _BucketizedColumn))
- def __new__(cls, columns, hash_bucket_size, combiner="sum"):
+ def __new__(cls, columns, hash_bucket_size, combiner="sum",
+ ckpt_to_load_from=None, tensor_name_in_ckpt=None):
for column in columns:
if not _CrossedColumn._is_crossable(column):
raise TypeError("columns should be a set of "
@@ -1154,10 +1198,16 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
if hash_bucket_size < 2:
raise ValueError("hash_bucket_size should be at least 2.")
+ if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
+ raise ValueError("Must specify both `ckpt_to_load_from` and "
+ "`tensor_name_in_ckpt` or none of them.")
+
sorted_columns = sorted([column for column in columns],
key=lambda column: column.name)
return super(_CrossedColumn, cls).__new__(cls, tuple(sorted_columns),
- hash_bucket_size, combiner)
+ hash_bucket_size, combiner,
+ ckpt_to_load_from,
+ tensor_name_in_ckpt)
@property
def name(self):
@@ -1223,7 +1273,7 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
num_outputs=1,
weight_collections=None,
trainable=True):
- return _create_embedding_lookup(
+ output, embedding_weights = _create_embedding_lookup(
input_tensor=input_tensor,
weight_tensor=None,
vocab_size=self.length,
@@ -1233,9 +1283,19 @@ class _CrossedColumn(_FeatureColumn, collections.namedtuple(
combiner=self.combiner,
trainable=trainable,
name=self.name + "_weights")
-
-
-def crossed_column(columns, hash_bucket_size, combiner="sum"):
+ if self.ckpt_to_load_from is not None:
+ weights_to_restore = embedding_weights
+ if len(embedding_weights) == 1:
+ weights_to_restore = embedding_weights[0]
+ checkpoint_utils.init_from_checkpoint(
+ self.ckpt_to_load_from,
+ {self.tensor_name_in_ckpt: weights_to_restore})
+ return output, embedding_weights
+
+
+def crossed_column(columns, hash_bucket_size, combiner="sum",
+ ckpt_to_load_from=None,
+ tensor_name_in_ckpt=None):
"""Creates a _CrossedColumn.
Args:
@@ -1243,6 +1303,12 @@ def crossed_column(columns, hash_bucket_size, combiner="sum"):
_SparseColumn, _CrossedColumn, or _BucketizedColumn.
hash_bucket_size: An int that is > 1. The number of buckets.
combiner: A combiner string, supports sum, mean, sqrtn.
+ ckpt_to_load_from: (Optional). String representing checkpoint name/pattern
+ from which to restore the column weights. Required if `tensor_name_in_ckpt`
+ is not None.
+ tensor_name_in_ckpt: (Optional). Name of the `Tensor` in the provided
+ checkpoint from which to restore the column weights. Required if
+ `ckpt_to_load_from` is not None.
Returns:
A _CrossedColumn.
@@ -1254,7 +1320,9 @@ def crossed_column(columns, hash_bucket_size, combiner="sum"):
ValueError: if hash_bucket_size is not > 1 or
len(columns) is not > 1.
"""
- return _CrossedColumn(columns, hash_bucket_size, combiner=combiner)
+ return _CrossedColumn(columns, hash_bucket_size, combiner=combiner,
+ ckpt_to_load_from=ckpt_to_load_from,
+ tensor_name_in_ckpt=tensor_name_in_ckpt)
class DataFrameColumn(_FeatureColumn,
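The analogous warm-start sketch for a crossed column, under the same assumptions (placeholder checkpoint path and tensor name; the weights variable name follows the `<col>_X_<col>_weights` pattern exercised by the test below):

col_1 = tf.contrib.layers.sparse_column_with_hash_bucket(
    column_name="col_1", hash_bucket_size=4)
col_2 = tf.contrib.layers.sparse_column_with_hash_bucket(
    column_name="col_2", hash_bucket_size=4)

# As with embedding_column, supplying only one of the two checkpoint
# arguments raises ValueError.
crossed = tf.contrib.layers.crossed_column(
    columns=[col_1, col_2],
    hash_bucket_size=4,
    ckpt_to_load_from="/tmp/model.ckpt",
    tensor_name_in_ckpt="run_1/col_1_X_col_2_weights")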
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py
index 187fadfce6..86d522dedf 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py
@@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
import tensorflow as tf
@@ -323,6 +325,107 @@ class FeatureColumnTest(tf.test.TestCase):
self.assertEqual(tf.float32, placeholder.dtype)
self.assertEqual([None, 1], placeholder.get_shape().as_list())
+ def testInitEmbeddingColumnWeightsFromCkpt(self):
+ sparse_col = tf.contrib.layers.sparse_column_with_hash_bucket(
+ column_name="object_in_image",
+ hash_bucket_size=4)
+ # Create an _EmbeddingColumn that randomly initializes an embedding of
+ # size [4, 16].
+ embedding_col = tf.contrib.layers.embedding_column(sparse_col, dimension=16)
+
+ # Create a SparseTensor that contains all the possible ids for the given
+ # vocab.
+ input_tensor = tf.SparseTensor(indices=[[0, 0], [1, 1], [2, 2], [3, 3]],
+ values=[0, 1, 2, 3],
+ shape=[4, 4])
+
+ # Invoking 'embedding_col.to_dnn_input_layer' will create the embedding
+ # variable. Create it under scope 'run_1' to prevent name conflicts with
+ # the embedding variable created later for 'embedding_col_initialized'.
+ with tf.variable_scope("run_1"):
+ # This returns a [4, 16] tensor that is the same as the embedding variable.
+ embeddings = embedding_col.to_dnn_input_layer(input_tensor)
+
+ save = tf.train.Saver()
+ checkpoint_path = os.path.join(self.get_temp_dir(), "model.ckpt")
+
+ with self.test_session() as sess:
+ sess.run(tf.initialize_all_variables())
+ saved_embedding = embeddings.eval()
+ save.save(sess, checkpoint_path)
+
+ embedding_col_initialized = tf.contrib.layers.embedding_column(
+ sparse_id_column=sparse_col,
+ dimension=16,
+ ckpt_to_load_from=checkpoint_path,
+ tensor_name_in_ckpt="run_1/object_in_image_embedding_weights")
+
+ with tf.variable_scope("run_2"):
+ # This will initialize the embedding from the provided checkpoint and
+ # return a [4, 16] tensor that is the same as the embedding variable.
+ # Since the embeddings were not modified, this should equal
+ # 'saved_embedding'.
+ pretrained_embeddings = embedding_col_initialized.to_dnn_input_layer(
+ input_tensor)
+
+ with self.test_session() as sess:
+ sess.run(tf.initialize_all_variables())
+ loaded_embedding = pretrained_embeddings.eval()
+
+ self.assertAllClose(saved_embedding, loaded_embedding)
+
+ def testInitCrossedColumnWeightsFromCkpt(self):
+ sparse_col_1 = tf.contrib.layers.sparse_column_with_hash_bucket(
+ column_name="col_1", hash_bucket_size=4)
+ sparse_col_2 = tf.contrib.layers.sparse_column_with_hash_bucket(
+ column_name="col_2", hash_bucket_size=4)
+
+ crossed_col = tf.contrib.layers.crossed_column(
+ columns=[sparse_col_1, sparse_col_2],
+ hash_bucket_size=4)
+
+ input_tensor = tf.SparseTensor(indices=[[0, 0], [1, 1], [2, 2], [3, 3]],
+ values=[0, 1, 2, 3],
+ shape=[4, 4])
+
+ # Invoking 'crossed_col.to_weighted_sum' will create the crossed column
+ # weights variable.
+ with tf.variable_scope("run_1"):
+ # Returns the looked-up column weights (which are the same as the crossed
+ # column weights) as well as actual references to the weights variables.
+ col_weights, weights = crossed_col.to_weighted_sum(input_tensor)
+ # Update the weights, since the default initializer sets all weights to
+ # 0.0.
+ for weight in weights:
+ assign_op = tf.assign(weight, weight + 0.5)
+
+ save = tf.train.Saver()
+ checkpoint_path = os.path.join(self.get_temp_dir(), "model.ckpt")
+
+ with self.test_session() as sess:
+ sess.run(tf.initialize_all_variables())
+ sess.run(assign_op)
+ saved_col_weights = col_weights.eval()
+ save.save(sess, checkpoint_path)
+
+ crossed_col_initialized = tf.contrib.layers.crossed_column(
+ columns=[sparse_col_1, sparse_col_2],
+ hash_bucket_size=4,
+ ckpt_to_load_from=checkpoint_path,
+ tensor_name_in_ckpt="run_1/col_1_X_col_2_weights")
+
+ with tf.variable_scope("run_2"):
+ # This will initialize the crossed column weights from the provided
+ # checkpoint and return a [4, 1] tensor that is the same as the weights
+ # variable. Since the weights are not modified, this should equal
+ # 'saved_col_weights'.
+ col_weights_from_ckpt, _ = crossed_col_initialized.to_weighted_sum(
+ input_tensor)
+
+ with self.test_session() as sess:
+ sess.run(tf.initialize_all_variables())
+ loaded_col_weights = col_weights_from_ckpt.eval()
+
+ self.assertAllClose(saved_col_weights, loaded_col_weights)
+
if __name__ == "__main__":
tf.test.main()
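A note on the restore mechanism these tests rely on: `checkpoint_utils.init_from_checkpoint` overrides the variable's initializer with one that reads the tensor from the checkpoint, which is why the tests only run `tf.initialize_all_variables()` and never call `Saver.restore` for the warm-started columns. A minimal sketch, reusing `embedding_col_initialized` and `input_tensor` from the first test above:

embeddings = embedding_col_initialized.to_dnn_input_layer(input_tensor)
with tf.Session() as sess:
    # Initialization now pulls the embedding values from the checkpoint
    # instead of running the random truncated-normal initializer.
    sess.run(tf.initialize_all_variables())
    warm_values = embeddings.eval()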