aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column/feature_column.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/feature_column/feature_column.py')
-rw-r--r--tensorflow/python/feature_column/feature_column.py337
1 files changed, 174 insertions, 163 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index f9201a4794..0ad8131599 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -135,6 +135,7 @@ import numpy as np
import six
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -462,6 +463,16 @@ def linear_model(features,
return predictions
+def _add_to_collections(var, weight_collections):
+ # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
+ # so that we don't have to do this check.
+ if isinstance(var, variables.PartitionedVariable):
+ for constituent_var in list(var):
+ ops.add_to_collections(weight_collections, constituent_var)
+ else:
+ ops.add_to_collections(weight_collections, var)
+
+
class _FCLinearWrapper(base.Layer):
"""Wraps a _FeatureColumn in a layer for use in a linear model.
@@ -482,12 +493,8 @@ class _FCLinearWrapper(base.Layer):
self._units = units
self._sparse_combiner = sparse_combiner
self._weight_collections = weight_collections
- self._state = {}
def build(self, _):
- self._state = self._feature_column._create_state( # pylint: disable=protected-access
- self._weight_collections, self.add_variable)
-
if isinstance(self._feature_column, _CategoricalColumn):
weight = self.add_variable(
name='weights',
@@ -501,7 +508,7 @@ class _FCLinearWrapper(base.Layer):
shape=[num_elements, self._units],
initializer=init_ops.zeros_initializer(),
trainable=self.trainable)
- ops.add_to_collections(self._weight_collections, weight)
+ _add_to_collections(weight, self._weight_collections)
self._weight_var = weight
self.built = True
@@ -513,8 +520,7 @@ class _FCLinearWrapper(base.Layer):
sparse_combiner=self._sparse_combiner,
weight_collections=self._weight_collections,
trainable=self.trainable,
- weight_var=self._weight_var,
- state=self._state)
+ weight_var=self._weight_var)
return weighted_sum
@@ -538,7 +544,7 @@ class _BiasLayer(base.Layer):
shape=[self._units],
initializer=init_ops.zeros_initializer(),
trainable=self.trainable)
- ops.add_to_collections(self._weight_collections, self._bias_variable)
+ _add_to_collections(self._bias_variable, self._weight_collections)
self.built = True
def call(self, _):
@@ -806,11 +812,22 @@ def embedding_column(
initializer = init_ops.truncated_normal_initializer(
mean=0.0, stddev=1 / math.sqrt(dimension))
+ embedding_shape = categorical_column._num_buckets, dimension # pylint: disable=protected-access
+
+ def _creator(weight_collections, scope):
+ embedding_column_layer = _EmbeddingColumnLayer(
+ embedding_shape=embedding_shape,
+ initializer=initializer,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ name='embedding_column_layer')
+ return embedding_column_layer(None, scope=scope) # pylint: disable=not-callable
+
return _EmbeddingColumn(
categorical_column=categorical_column,
dimension=dimension,
combiner=combiner,
- initializer=initializer,
+ layer_creator=_creator,
ckpt_to_load_from=ckpt_to_load_from,
tensor_name_in_ckpt=tensor_name_in_ckpt,
max_norm=max_norm,
@@ -933,6 +950,7 @@ def shared_embedding_columns(
sorted_columns = sorted(categorical_columns, key=lambda x: x.name)
c0 = sorted_columns[0]
+ num_buckets = c0._num_buckets # pylint: disable=protected-access
if not isinstance(c0, _CategoricalColumn):
raise ValueError(
'All categorical_columns must be subclasses of _CategoricalColumn. '
@@ -948,23 +966,45 @@ def shared_embedding_columns(
'the same type, or be weighted_categorical_column of the same type. '
'Given column: {} of type: {} does not match given column: {} of '
'type: {}'.format(c0, type(c0), c, type(c)))
+ if num_buckets != c._num_buckets: # pylint: disable=protected-access
+ raise ValueError(
+ 'To use shared_embedding_column, all categorical_columns must have '
+ 'the same number of buckets. Given column: {} with buckets: {} does '
+ 'not match column: {} with buckets: {}'.format(
+ c0, num_buckets, c, c._num_buckets)) # pylint: disable=protected-access
if not shared_embedding_collection_name:
shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
shared_embedding_collection_name += '_shared_embedding'
+ # Create the state (_SharedEmbeddingColumnLayer) here.
+ embedding_shape = num_buckets, dimension
+
+ shared_embedding_column_layer = _EmbeddingColumnLayer(
+ embedding_shape=embedding_shape,
+ initializer=initializer,
+ weight_collections=[],
+ trainable=trainable,
+ name=shared_embedding_collection_name)
+
result = []
for column in categorical_columns:
- result.append(_SharedEmbeddingColumn(
- categorical_column=column,
- dimension=dimension,
- combiner=combiner,
- initializer=initializer,
- shared_embedding_collection_name=shared_embedding_collection_name,
- ckpt_to_load_from=ckpt_to_load_from,
- tensor_name_in_ckpt=tensor_name_in_ckpt,
- max_norm=max_norm,
- trainable=trainable))
+ result.append(
+ _SharedEmbeddingColumn(
+ categorical_column=column,
+ initializer=initializer,
+ dimension=dimension,
+ combiner=combiner,
+ var_scope_name=shared_embedding_collection_name,
+ ckpt_to_load_from=ckpt_to_load_from,
+ tensor_name_in_ckpt=tensor_name_in_ckpt,
+ max_norm=max_norm,
+ trainable=trainable))
+
+ for single_result in result:
+ single_result._set_layer(shared_embedding_column_layer) # pylint: disable=protected-access
+ single_result._set_all_columns(result) # pylint: disable=protected-access
+
return result
@@ -1721,6 +1761,57 @@ def crossed_column(keys, hash_bucket_size, hash_key=None):
hash_key=hash_key)
+# TODO(rohanj): Clearly define semantics of this layer.
+class _EmbeddingColumnLayer(base.Layer):
+ """A layer that stores all the state required for a embedding column."""
+
+ def __init__(self,
+ embedding_shape,
+ initializer,
+ weight_collections=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ """Constructor.
+
+ Args:
+ embedding_shape: Shape of the embedding variable used for lookup.
+ initializer: A variable initializer function to be used in embedding
+ variable initialization. If not specified, defaults to
+ `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
+ `1/sqrt(dimension)`.
+ weight_collections: A list of collection names to which the Variable will
+ be added. Note that, variables will also be added to collections
+ `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: Name of the layer
+ **kwargs: keyword named properties.
+ """
+ super(_EmbeddingColumnLayer, self).__init__(
+ trainable=trainable, name=name, **kwargs)
+ self._embedding_shape = embedding_shape
+ self._initializer = initializer
+ self._weight_collections = weight_collections
+
+ def build(self, _):
+ self._embedding_weight_var = self.add_variable(
+ name='embedding_weights',
+ shape=self._embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self._initializer,
+ trainable=self.trainable)
+ # self.add_variable already appends to GLOBAL_VARIABLES collection.
+ if self._weight_collections and not context.executing_eagerly():
+ for weight_collection in self._weight_collections:
+ if weight_collection != ops.GraphKeys.GLOBAL_VARIABLES:
+ _add_to_collections(self._embedding_weight_var, [weight_collection])
+ self.built = True
+
+ def call(self, _):
+ return self._embedding_weight_var
+
+
class _FeatureColumn(object):
"""Represents a feature column abstraction.
@@ -1794,18 +1885,13 @@ class _FeatureColumn(object):
"""
pass
- def _create_state(self, weight_collections=None, creator=None):
- """Returns an object that captures the state of the column.
+ def _reset_config(self):
+ """Resets the configuration in the column.
- Args:
- weight_collections: Collections to add the variable to
- creator: Variable creator method called, if provided.
-
- Returns:
- An object that encapsulates the state of the column. Can return None.
+ Some feature columns e.g. embedding or shared embedding columns might
+ have some state that is needed to be reset sometimes. Use this method
+ in that scenario.
"""
- del weight_collections, creator # Unused
- return None
class _DenseColumn(_FeatureColumn):
@@ -1826,11 +1912,7 @@ class _DenseColumn(_FeatureColumn):
pass
@abc.abstractmethod
- def _get_dense_tensor(self,
- inputs,
- weight_collections=None,
- trainable=None,
- state=None):
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
"""Returns a `Tensor`.
The output of this function will be used by model-builder-functions. For
@@ -1848,9 +1930,6 @@ class _DenseColumn(_FeatureColumn):
will be created) are added.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see @{tf.Variable}).
- state: An object encapsulating the state of the column. Columns that
- create state using the _create_state method would have that state
- passed in to this method.
Returns:
`Tensor` of shape [batch_size] + `_variable_shape`.
@@ -1864,8 +1943,7 @@ def _create_weighted_sum(column,
sparse_combiner,
weight_collections,
trainable,
- weight_var=None,
- state=None):
+ weight_var=None):
"""Creates a weighted sum for a dense or sparse column for linear_model."""
if isinstance(column, _CategoricalColumn):
return _create_categorical_column_weighted_sum(
@@ -1883,8 +1961,7 @@ def _create_weighted_sum(column,
units=units,
weight_collections=weight_collections,
trainable=trainable,
- weight_var=weight_var,
- state=state)
+ weight_var=weight_var)
def _create_dense_column_weighted_sum(column,
@@ -1892,20 +1969,12 @@ def _create_dense_column_weighted_sum(column,
units,
weight_collections,
trainable,
- weight_var=None,
- state=None):
+ weight_var=None):
"""Create a weighted sum of a dense column for linear_model."""
- if state is not None:
- tensor = column._get_dense_tensor( # pylint: disable=protected-access
- builder,
- weight_collections=weight_collections,
- trainable=trainable,
- state=state)
- else:
- tensor = column._get_dense_tensor( # pylint: disable=protected-access
- builder,
- weight_collections=weight_collections,
- trainable=trainable)
+ tensor = column._get_dense_tensor( # pylint: disable=protected-access
+ builder,
+ weight_collections=weight_collections,
+ trainable=trainable)
num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access
batch_size = array_ops.shape(tensor)[0]
tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
@@ -2368,10 +2437,10 @@ class _BucketizedColumn(_DenseColumn, _CategoricalColumn,
class _EmbeddingColumn(
_DenseColumn, _SequenceDenseColumn,
- collections.namedtuple('_EmbeddingColumn', (
- 'categorical_column', 'dimension', 'combiner', 'initializer',
- 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable'
- ))):
+ collections.namedtuple(
+ '_EmbeddingColumn',
+ ('categorical_column', 'dimension', 'combiner', 'layer_creator',
+ 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable'))):
"""See `embedding_column`."""
@property
@@ -2393,33 +2462,10 @@ class _EmbeddingColumn(
self._shape = tensor_shape.vector(self.dimension)
return self._shape
- def _create_state(self, weight_collections=None, creator=None):
- variables_map = {}
- embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
- if creator is not None:
- embedding_weights = creator(
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer,
- trainable=self.trainable)
- ops.add_to_collections(weight_collections, embedding_weights)
- else:
- embedding_weights = variable_scope.get_variable(
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer,
- trainable=self.trainable,
- collections=weight_collections)
- variables_map['embedding_weights'] = embedding_weights
- return variables_map
-
def _get_dense_tensor_internal(self,
inputs,
weight_collections=None,
- trainable=None,
- state=None):
+ trainable=None):
"""Private method that follows the signature of _get_dense_tensor."""
# Get sparse IDs and weights.
sparse_tensors = self.categorical_column._get_sparse_tensors( # pylint: disable=protected-access
@@ -2427,9 +2473,9 @@ class _EmbeddingColumn(
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- if state is None:
- state = self._create_state(weight_collections)
- embedding_weights = state['embedding_weights']
+ embedding_weights = self.layer_creator(
+ weight_collections=weight_collections,
+ scope=variable_scope.get_variable_scope())
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
@@ -2448,11 +2494,7 @@ class _EmbeddingColumn(
name='%s_weights' % self.name,
max_norm=self.max_norm)
- def _get_dense_tensor(self,
- inputs,
- weight_collections=None,
- trainable=None,
- state=None):
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
if isinstance(self.categorical_column, _SequenceCategoricalColumn):
raise ValueError(
'In embedding_column: {}. '
@@ -2467,8 +2509,7 @@ class _EmbeddingColumn(
return self._get_dense_tensor_internal(
inputs=inputs,
weight_collections=weight_collections,
- trainable=trainable,
- state=state)
+ trainable=trainable)
def _get_sequence_dense_tensor(
self, inputs, weight_collections=None, trainable=None):
@@ -2492,13 +2533,20 @@ class _EmbeddingColumn(
dense_tensor=dense_tensor, sequence_length=sequence_length)
+def _get_graph_for_variable(var):
+ if isinstance(var, variables.PartitionedVariable):
+ return list(var)[0].graph
+ else:
+ return var.graph
+
+
class _SharedEmbeddingColumn(
_DenseColumn,
- collections.namedtuple('_SharedEmbeddingColumn', (
- 'categorical_column', 'dimension', 'combiner', 'initializer',
- 'shared_embedding_collection_name', 'ckpt_to_load_from',
- 'tensor_name_in_ckpt', 'max_norm', 'trainable'
- ))):
+ collections.namedtuple(
+ '_SharedEmbeddingColumn',
+ ('categorical_column', 'dimension', 'combiner', 'initializer',
+ 'var_scope_name', 'ckpt_to_load_from', 'tensor_name_in_ckpt',
+ 'max_norm', 'trainable'))):
"""See `embedding_column`."""
@property
@@ -2509,7 +2557,7 @@ class _SharedEmbeddingColumn(
@property
def _var_scope_name(self):
- return self.shared_embedding_collection_name
+ return self.var_scope_name
@property
def _parse_example_spec(self):
@@ -2518,45 +2566,29 @@ class _SharedEmbeddingColumn(
def _transform_feature(self, inputs):
return inputs.get(self.categorical_column)
+ def _set_layer(self, layer):
+ self._layer = layer
+
+ def _set_all_columns(self, all_columns):
+ self._all_columns = all_columns
+
+ def _reset_config(self):
+ config = self._layer.get_config()
+ config['embedding_shape'] = (
+ self.categorical_column._num_buckets, # pylint: disable=protected-access
+ self.dimension)
+ config['initializer'] = self.initializer
+ self._layer = self._layer.__class__.from_config(config)
+ for column in self._all_columns:
+ column._set_layer(self._layer) # pylint: disable=protected-access
+
@property
def _variable_shape(self):
if not hasattr(self, '_shape'):
self._shape = tensor_shape.vector(self.dimension)
return self._shape
- def _create_state(self, weight_collections=None, creator=None):
- variables_map = {}
- shared_embedding_collection = ops.get_collection(
- self.shared_embedding_collection_name)
- if not shared_embedding_collection:
- embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
- if creator is not None:
- embedding_weights = creator(
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer,
- trainable=self.trainable)
- ops.add_to_collections(weight_collections, embedding_weights)
- else:
- embedding_weights = variable_scope.get_variable(
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer,
- trainable=self.trainable,
- collections=weight_collections)
- ops.add_to_collection(self.shared_embedding_collection_name,
- embedding_weights)
- variables_map['embedding_weights'] = embedding_weights
-
- return variables_map
-
- def _get_dense_tensor(self,
- inputs,
- weight_collections=None,
- trainable=None,
- state=None):
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
# This method is called from a variable_scope with name _var_scope_name,
# which is shared among all shared embeddings. Open a name_scope here, so
# that the ops for different columns have distinct names.
@@ -2567,38 +2599,17 @@ class _SharedEmbeddingColumn(
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
- shared_embedding_collection = ops.get_collection(
- self.shared_embedding_collection_name)
- if shared_embedding_collection:
- if len(shared_embedding_collection) > 1:
- raise ValueError(
- 'Collection {} can only contain one variable. '
- 'Suggested fix A: Choose a unique name for this collection. '
- 'Suggested fix B: Do not add any variables to this collection. '
- 'The feature_column library already adds a variable under the '
- 'hood.'.format(shared_embedding_collection))
- embedding_weights = shared_embedding_collection[0]
- if embedding_weights.get_shape() != embedding_shape:
- raise ValueError(
- 'Shared embedding collection {} contains variable {} of '
- 'unexpected shape {}. Expected shape is {}. '
- 'Suggested fix A: Choose a unique name for this collection. '
- 'Suggested fix B: Do not add any variables to this collection. '
- 'The feature_column library already adds a variable under the '
- 'hood.'.format(
- self.shared_embedding_collection_name, embedding_weights.name,
- embedding_weights.get_shape(), embedding_shape))
- else:
- embedding_weights = variable_scope.get_variable(
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer,
- trainable=self.trainable and trainable,
- collections=weight_collections)
- ops.add_to_collection(
- self.shared_embedding_collection_name, embedding_weights)
+ embedding_weights = self._layer(
+ None, scope=variable_scope.get_variable_scope())
+ # If we're in graph mode and this is called with a different graph,
+ # then we should reset.
+ if not context.executing_eagerly() and (
+ ops.get_default_graph() !=
+ _get_graph_for_variable(embedding_weights)):
+ self._reset_config()
+ embedding_weights = self._layer(
+ None, scope=variable_scope.get_variable_scope())
+
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
if isinstance(to_restore, variables.PartitionedVariable):