Diffstat (limited to 'tensorflow/python/feature_column/feature_column.py')
-rw-r--r-- | tensorflow/python/feature_column/feature_column.py | 337
1 file changed, 174 insertions, 163 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index f9201a4794..0ad8131599 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -135,6 +135,7 @@
 import numpy as np
 import six
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -462,6 +463,16 @@ def linear_model(features,
   return predictions
 
 
+def _add_to_collections(var, weight_collections):
+  # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
+  # so that we don't have to do this check.
+  if isinstance(var, variables.PartitionedVariable):
+    for constituent_var in list(var):
+      ops.add_to_collections(weight_collections, constituent_var)
+  else:
+    ops.add_to_collections(weight_collections, var)
+
+
 class _FCLinearWrapper(base.Layer):
   """Wraps a _FeatureColumn in a layer for use in a linear model.
 
@@ -482,12 +493,8 @@ class _FCLinearWrapper(base.Layer):
     self._units = units
     self._sparse_combiner = sparse_combiner
     self._weight_collections = weight_collections
-    self._state = {}
 
   def build(self, _):
-    self._state = self._feature_column._create_state(  # pylint: disable=protected-access
-        self._weight_collections, self.add_variable)
-
     if isinstance(self._feature_column, _CategoricalColumn):
       weight = self.add_variable(
           name='weights',
@@ -501,7 +508,7 @@ class _FCLinearWrapper(base.Layer):
           shape=[num_elements, self._units],
           initializer=init_ops.zeros_initializer(),
           trainable=self.trainable)
-    ops.add_to_collections(self._weight_collections, weight)
+    _add_to_collections(weight, self._weight_collections)
     self._weight_var = weight
     self.built = True
 
@@ -513,8 +520,7 @@ class _FCLinearWrapper(base.Layer):
         sparse_combiner=self._sparse_combiner,
         weight_collections=self._weight_collections,
         trainable=self.trainable,
-        weight_var=self._weight_var,
-        state=self._state)
+        weight_var=self._weight_var)
     return weighted_sum
 
 
@@ -538,7 +544,7 @@ class _BiasLayer(base.Layer):
         shape=[self._units],
         initializer=init_ops.zeros_initializer(),
         trainable=self.trainable)
-    ops.add_to_collections(self._weight_collections, self._bias_variable)
+    _add_to_collections(self._bias_variable, self._weight_collections)
     self.built = True
 
   def call(self, _):
@@ -806,11 +812,22 @@ def embedding_column(
     initializer = init_ops.truncated_normal_initializer(
         mean=0.0, stddev=1 / math.sqrt(dimension))
 
+  embedding_shape = categorical_column._num_buckets, dimension  # pylint: disable=protected-access
+
+  def _creator(weight_collections, scope):
+    embedding_column_layer = _EmbeddingColumnLayer(
+        embedding_shape=embedding_shape,
+        initializer=initializer,
+        weight_collections=weight_collections,
+        trainable=trainable,
+        name='embedding_column_layer')
+    return embedding_column_layer(None, scope=scope)  # pylint: disable=not-callable
+
   return _EmbeddingColumn(
       categorical_column=categorical_column,
       dimension=dimension,
      combiner=combiner,
-      initializer=initializer,
+      layer_creator=_creator,
       ckpt_to_load_from=ckpt_to_load_from,
       tensor_name_in_ckpt=tensor_name_in_ckpt,
       max_norm=max_norm,
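The new `_add_to_collections` helper exists because a partitioned embedding variable comes back as a `PartitionedVariable` wrapper rather than a plain `Variable`; iterating the wrapper yields the concrete per-shard variables, which are what belong in graph collections. A minimal sketch of the behavior the helper relies on, written against the TF 1.x API (the collection name is illustrative):

import tensorflow as tf

# A variable split across two shards comes back as a PartitionedVariable.
var = tf.get_variable(
    'embedding_weights', shape=[10, 4],
    partitioner=tf.fixed_size_partitioner(2))

# Iterating the wrapper yields the constituent shard Variables; adding those
# (rather than the wrapper) is what _add_to_collections does.
for shard in list(var):
  tf.add_to_collection('my_weights', shard)  # collection name is made up

print(len(tf.get_collection('my_weights')))  # -> 2
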
@@ -933,6 +950,7 @@ def shared_embedding_columns(
 
   sorted_columns = sorted(categorical_columns, key=lambda x: x.name)
   c0 = sorted_columns[0]
+  num_buckets = c0._num_buckets  # pylint: disable=protected-access
   if not isinstance(c0, _CategoricalColumn):
     raise ValueError(
         'All categorical_columns must be subclasses of _CategoricalColumn. '
@@ -948,23 +966,45 @@ def shared_embedding_columns(
         'the same type, or be weighted_categorical_column of the same type. '
         'Given column: {} of type: {} does not match given column: {} of '
         'type: {}'.format(c0, type(c0), c, type(c)))
+    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
+      raise ValueError(
+          'To use shared_embedding_column, all categorical_columns must have '
+          'the same number of buckets. Given column: {} with buckets: {} does '
+          'not match column: {} with buckets: {}'.format(
+              c0, num_buckets, c, c._num_buckets))  # pylint: disable=protected-access
 
   if not shared_embedding_collection_name:
     shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
     shared_embedding_collection_name += '_shared_embedding'
 
+  # Create the state (_EmbeddingColumnLayer) here.
+  embedding_shape = num_buckets, dimension
+
+  shared_embedding_column_layer = _EmbeddingColumnLayer(
+      embedding_shape=embedding_shape,
+      initializer=initializer,
+      weight_collections=[],
+      trainable=trainable,
+      name=shared_embedding_collection_name)
+
   result = []
   for column in categorical_columns:
-    result.append(_SharedEmbeddingColumn(
-        categorical_column=column,
-        dimension=dimension,
-        combiner=combiner,
-        initializer=initializer,
-        shared_embedding_collection_name=shared_embedding_collection_name,
-        ckpt_to_load_from=ckpt_to_load_from,
-        tensor_name_in_ckpt=tensor_name_in_ckpt,
-        max_norm=max_norm,
-        trainable=trainable))
+    result.append(
+        _SharedEmbeddingColumn(
+            categorical_column=column,
+            initializer=initializer,
+            dimension=dimension,
+            combiner=combiner,
+            var_scope_name=shared_embedding_collection_name,
+            ckpt_to_load_from=ckpt_to_load_from,
+            tensor_name_in_ckpt=tensor_name_in_ckpt,
+            max_norm=max_norm,
+            trainable=trainable))
+
+  for single_result in result:
+    single_result._set_layer(shared_embedding_column_layer)  # pylint: disable=protected-access
+    single_result._set_all_columns(result)  # pylint: disable=protected-access
+
   return result
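At the public API level, the new `num_buckets` check means every column passed to `shared_embedding_columns` must be backed by a vocabulary of the same size, since all of them now share a single `_EmbeddingColumnLayer`. A sketch of the intended usage (TF 1.x; the feature names are illustrative):

import tensorflow as tf

# Two id columns over the same 1000-entry vocabulary.
watched = tf.feature_column.categorical_column_with_identity(
    'watched_video_id', num_buckets=1000)
impression = tf.feature_column.categorical_column_with_identity(
    'impression_video_id', num_buckets=1000)

# Both columns look up into one shared [1000, 16] embedding table.
# Passing columns with mismatched num_buckets now raises ValueError.
columns = tf.feature_column.shared_embedding_columns(
    [watched, impression], dimension=16)
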
@@ -1721,6 +1761,57 @@ def crossed_column(keys, hash_bucket_size, hash_key=None):
       hash_key=hash_key)
 
 
+# TODO(rohanj): Clearly define semantics of this layer.
+class _EmbeddingColumnLayer(base.Layer):
+  """A layer that stores all the state required for an embedding column."""
+
+  def __init__(self,
+               embedding_shape,
+               initializer,
+               weight_collections=None,
+               trainable=True,
+               name=None,
+               **kwargs):
+    """Constructor.
+
+    Args:
+      embedding_shape: Shape of the embedding variable used for lookup.
+      initializer: A variable initializer function to be used in embedding
+        variable initialization. If not specified, defaults to
+        `tf.truncated_normal_initializer` with mean `0.0` and standard
+        deviation `1/sqrt(dimension)`.
+      weight_collections: A list of collection names to which the Variable
+        will be added. Note that variables will also be added to the
+        collections `tf.GraphKeys.GLOBAL_VARIABLES` and
+        `ops.GraphKeys.MODEL_VARIABLES`.
+      trainable: If `True`, also add the variable to the graph collection
+        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+      name: Name of the layer.
+      **kwargs: keyword named properties.
+    """
+    super(_EmbeddingColumnLayer, self).__init__(
+        trainable=trainable, name=name, **kwargs)
+    self._embedding_shape = embedding_shape
+    self._initializer = initializer
+    self._weight_collections = weight_collections
+
+  def build(self, _):
+    self._embedding_weight_var = self.add_variable(
+        name='embedding_weights',
+        shape=self._embedding_shape,
+        dtype=dtypes.float32,
+        initializer=self._initializer,
+        trainable=self.trainable)
+    # self.add_variable already appends to the GLOBAL_VARIABLES collection.
+    if self._weight_collections and not context.executing_eagerly():
+      for weight_collection in self._weight_collections:
+        if weight_collection != ops.GraphKeys.GLOBAL_VARIABLES:
+          _add_to_collections(self._embedding_weight_var, [weight_collection])
+    self.built = True
+
+  def call(self, _):
+    return self._embedding_weight_var
+
+
 class _FeatureColumn(object):
   """Represents a feature column abstraction.
 
@@ -1794,18 +1885,13 @@ class _FeatureColumn(object):
     """
     pass
 
-  def _create_state(self, weight_collections=None, creator=None):
-    """Returns an object that captures the state of the column.
+  def _reset_config(self):
+    """Resets the configuration in the column.
 
-    Args:
-      weight_collections: Collections to add the variable to
-      creator: Variable creator method called, if provided.
-
-    Returns:
-      An object that encapsulates the state of the column. Can return None.
+    Some feature columns, e.g. embedding or shared embedding columns, might
+    have some state that needs to be reset sometimes. Use this method in
+    that scenario.
     """
-    del weight_collections, creator  # Unused
-    return None
 
 
 class _DenseColumn(_FeatureColumn):
@@ -1826,11 +1912,7 @@ class _DenseColumn(_FeatureColumn):
     pass
 
   @abc.abstractmethod
-  def _get_dense_tensor(self,
-                        inputs,
-                        weight_collections=None,
-                        trainable=None,
-                        state=None):
+  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     """Returns a `Tensor`.
 
     The output of this function will be used by model-builder-functions. For
@@ -1848,9 +1930,6 @@ class _DenseColumn(_FeatureColumn):
         will be created) are added.
      trainable: If `True` also add variables to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see @{tf.Variable}).
-      state: An object encapsulating the state of the column. Columns that
-        create state using the _create_state method would have that state
-        passed in to this method.
 
     Returns:
       `Tensor` of shape [batch_size] + `_variable_shape`.
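`_EmbeddingColumnLayer` follows an unusual convention: `call` ignores its input and simply returns the embedding variable, so the layer is really a state holder that routes variable creation through `Layer.add_variable` (and thereby through variable scopes, partitioners, and eager-mode handling). A toy analogue of the pattern, with illustrative names, sketched against the TF 1.x `tf.layers` API and following the same call-with-None convention the new code uses:

import tensorflow as tf

class _VariableHolder(tf.layers.Layer):
  """Toy analogue of _EmbeddingColumnLayer: call() just returns a variable."""

  def __init__(self, shape, **kwargs):
    super(_VariableHolder, self).__init__(**kwargs)
    self._shape = shape

  def build(self, _):
    # Routing creation through add_variable gives scoping, partitioning and
    # eager support for free.
    self._weights = self.add_variable(
        'weights', shape=self._shape, dtype=tf.float32,
        initializer=tf.zeros_initializer())
    self.built = True

  def call(self, _):
    return self._weights

holder = _VariableHolder(shape=[10, 4])
w1 = holder(None)  # first call builds the variable; the None input is ignored
w2 = holder(None)  # later calls return the same variable, no new state
assert w1 is w2
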
@@ -1864,8 +1943,7 @@ def _create_weighted_sum(column,
                          sparse_combiner,
                          weight_collections,
                          trainable,
-                         weight_var=None,
-                         state=None):
+                         weight_var=None):
   """Creates a weighted sum for a dense or sparse column for linear_model."""
   if isinstance(column, _CategoricalColumn):
     return _create_categorical_column_weighted_sum(
@@ -1883,8 +1961,7 @@ def _create_weighted_sum(column,
         units=units,
         weight_collections=weight_collections,
         trainable=trainable,
-        weight_var=weight_var,
-        state=state)
+        weight_var=weight_var)
 
 
 def _create_dense_column_weighted_sum(column,
@@ -1892,20 +1969,12 @@ def _create_dense_column_weighted_sum(column,
                                       units,
                                       weight_collections,
                                       trainable,
-                                      weight_var=None,
-                                      state=None):
+                                      weight_var=None):
   """Create a weighted sum of a dense column for linear_model."""
-  if state is not None:
-    tensor = column._get_dense_tensor(  # pylint: disable=protected-access
-        builder,
-        weight_collections=weight_collections,
-        trainable=trainable,
-        state=state)
-  else:
-    tensor = column._get_dense_tensor(  # pylint: disable=protected-access
-        builder,
-        weight_collections=weight_collections,
-        trainable=trainable)
+  tensor = column._get_dense_tensor(  # pylint: disable=protected-access
+      builder,
+      weight_collections=weight_collections,
+      trainable=trainable)
   num_elements = column._variable_shape.num_elements()  # pylint: disable=protected-access
   batch_size = array_ops.shape(tensor)[0]
   tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
@@ -2368,10 +2437,10 @@ class _BucketizedColumn(_DenseColumn, _CategoricalColumn,
 
 class _EmbeddingColumn(
     _DenseColumn, _SequenceDenseColumn,
-    collections.namedtuple('_EmbeddingColumn', (
-        'categorical_column', 'dimension', 'combiner', 'initializer',
-        'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable'
-    ))):
+    collections.namedtuple(
+        '_EmbeddingColumn',
+        ('categorical_column', 'dimension', 'combiner', 'layer_creator',
+         'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable'))):
   """See `embedding_column`."""
 
   @property
@@ -2393,33 +2462,10 @@ class _EmbeddingColumn(
     self._shape = tensor_shape.vector(self.dimension)
     return self._shape
 
-  def _create_state(self, weight_collections=None, creator=None):
-    variables_map = {}
-    embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
-    if creator is not None:
-      embedding_weights = creator(
-          name='embedding_weights',
-          shape=embedding_shape,
-          dtype=dtypes.float32,
-          initializer=self.initializer,
-          trainable=self.trainable)
-      ops.add_to_collections(weight_collections, embedding_weights)
-    else:
-      embedding_weights = variable_scope.get_variable(
-          name='embedding_weights',
-          shape=embedding_shape,
-          dtype=dtypes.float32,
-          initializer=self.initializer,
-          trainable=self.trainable,
-          collections=weight_collections)
-    variables_map['embedding_weights'] = embedding_weights
-    return variables_map
-
   def _get_dense_tensor_internal(self,
                                  inputs,
                                  weight_collections=None,
-                                 trainable=None,
-                                 state=None):
+                                 trainable=None):
     """Private method that follows the signature of _get_dense_tensor."""
     # Get sparse IDs and weights.
     sparse_tensors = self.categorical_column._get_sparse_tensors(  # pylint: disable=protected-access
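With `state` gone, `_create_dense_column_weighted_sum` reduces to: fetch the column's dense tensor, flatten it to `[batch_size, num_elements]`, and multiply by the externally created `weight_var` of shape `[num_elements, units]`. A small numeric sketch of that tail-end computation (TF 1.x graph mode; the shapes are made up):

import tensorflow as tf

batch_size, num_elements, units = 2, 3, 1
tensor = tf.reshape(
    tf.range(batch_size * num_elements, dtype=tf.float32),
    (batch_size, num_elements))             # the column's flattened output
weight_var = tf.ones((num_elements, units))  # created by _FCLinearWrapper
weighted_sum = tf.matmul(tensor, weight_var)  # shape [batch_size, units]

with tf.Session() as sess:
  print(sess.run(weighted_sum))  # [[3.], [12.]]
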
@@ -2427,9 +2473,9 @@ class _EmbeddingColumn(
     sparse_ids = sparse_tensors.id_tensor
     sparse_weights = sparse_tensors.weight_tensor
 
-    if state is None:
-      state = self._create_state(weight_collections)
-    embedding_weights = state['embedding_weights']
+    embedding_weights = self.layer_creator(
+        weight_collections=weight_collections,
+        scope=variable_scope.get_variable_scope())
 
     if self.ckpt_to_load_from is not None:
       to_restore = embedding_weights
@@ -2448,11 +2494,7 @@ class _EmbeddingColumn(
         name='%s_weights' % self.name,
         max_norm=self.max_norm)
 
-  def _get_dense_tensor(self,
-                        inputs,
-                        weight_collections=None,
-                        trainable=None,
-                        state=None):
+  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     if isinstance(self.categorical_column, _SequenceCategoricalColumn):
       raise ValueError(
           'In embedding_column: {}. '
@@ -2467,8 +2509,7 @@ class _EmbeddingColumn(
     return self._get_dense_tensor_internal(
         inputs=inputs,
         weight_collections=weight_collections,
-        trainable=trainable,
-        state=state)
+        trainable=trainable)
 
   def _get_sequence_dense_tensor(
       self, inputs, weight_collections=None, trainable=None):
@@ -2492,13 +2533,20 @@ class _EmbeddingColumn(
         dense_tensor=dense_tensor, sequence_length=sequence_length)
 
 
+def _get_graph_for_variable(var):
+  if isinstance(var, variables.PartitionedVariable):
+    return list(var)[0].graph
+  else:
+    return var.graph
+
+
 class _SharedEmbeddingColumn(
     _DenseColumn,
-    collections.namedtuple('_SharedEmbeddingColumn', (
-        'categorical_column', 'dimension', 'combiner', 'initializer',
-        'shared_embedding_collection_name', 'ckpt_to_load_from',
-        'tensor_name_in_ckpt', 'max_norm', 'trainable'
-    ))):
+    collections.namedtuple(
+        '_SharedEmbeddingColumn',
+        ('categorical_column', 'dimension', 'combiner', 'initializer',
+         'var_scope_name', 'ckpt_to_load_from', 'tensor_name_in_ckpt',
+         'max_norm', 'trainable'))):
   """See `embedding_column`."""
 
   @property
@@ -2509,7 +2557,7 @@ class _SharedEmbeddingColumn(
 
   @property
   def _var_scope_name(self):
-    return self.shared_embedding_collection_name
+    return self.var_scope_name
 
   @property
   def _parse_example_spec(self):
@@ -2518,45 +2566,29 @@ class _SharedEmbeddingColumn(
   def _transform_feature(self, inputs):
     return inputs.get(self.categorical_column)
 
+  def _set_layer(self, layer):
+    self._layer = layer
+
+  def _set_all_columns(self, all_columns):
+    self._all_columns = all_columns
+
+  def _reset_config(self):
+    config = self._layer.get_config()
+    config['embedding_shape'] = (
+        self.categorical_column._num_buckets,  # pylint: disable=protected-access
+        self.dimension)
+    config['initializer'] = self.initializer
+    self._layer = self._layer.__class__.from_config(config)
+    for column in self._all_columns:
+      column._set_layer(self._layer)  # pylint: disable=protected-access
+
   @property
   def _variable_shape(self):
     if not hasattr(self, '_shape'):
       self._shape = tensor_shape.vector(self.dimension)
     return self._shape
 
-  def _create_state(self, weight_collections=None, creator=None):
-    variables_map = {}
-    shared_embedding_collection = ops.get_collection(
-        self.shared_embedding_collection_name)
-    if not shared_embedding_collection:
-      embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
-      if creator is not None:
-        embedding_weights = creator(
-            name='embedding_weights',
-            shape=embedding_shape,
-            dtype=dtypes.float32,
-            initializer=self.initializer,
-            trainable=self.trainable)
-        ops.add_to_collections(weight_collections, embedding_weights)
-      else:
-        embedding_weights = variable_scope.get_variable(
-            name='embedding_weights',
-            shape=embedding_shape,
-            dtype=dtypes.float32,
-            initializer=self.initializer,
-            trainable=self.trainable,
-            collections=weight_collections)
-      ops.add_to_collection(self.shared_embedding_collection_name,
-                            embedding_weights)
-      variables_map['embedding_weights'] = embedding_weights
-
-    return variables_map
-
-  def _get_dense_tensor(self,
-                        inputs,
-                        weight_collections=None,
-                        trainable=None,
-                        state=None):
+  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     # This method is called from a variable_scope with name _var_scope_name,
     # which is shared among all shared embeddings. Open a name_scope here, so
     # that the ops for different columns have distinct names.
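`_reset_config` leans on the standard layer `get_config`/`from_config` round trip: it serializes the current layer's config, patches in the fields the base `Layer` config does not carry (`embedding_shape`, `initializer`), and rebuilds a fresh, unbuilt layer that every column in the shared group is then pointed at. The round trip itself, sketched with a stock Keras layer (the Dense example is illustrative, not the class the diff uses):

import tensorflow as tf

dense = tf.keras.layers.Dense(4)
config = dense.get_config()     # plain dict of constructor arguments
config['units'] = 8             # patch the config, as _reset_config does
fresh = dense.__class__.from_config(config)

assert fresh.units == 8
assert not fresh.built          # new instance; variables not yet created
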
@@ -2567,38 +2599,17 @@ class _SharedEmbeddingColumn(
     sparse_ids = sparse_tensors.id_tensor
     sparse_weights = sparse_tensors.weight_tensor
 
-    embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
-    shared_embedding_collection = ops.get_collection(
-        self.shared_embedding_collection_name)
-    if shared_embedding_collection:
-      if len(shared_embedding_collection) > 1:
-        raise ValueError(
-            'Collection {} can only contain one variable. '
-            'Suggested fix A: Choose a unique name for this collection. '
-            'Suggested fix B: Do not add any variables to this collection. '
-            'The feature_column library already adds a variable under the '
-            'hood.'.format(shared_embedding_collection))
-      embedding_weights = shared_embedding_collection[0]
-      if embedding_weights.get_shape() != embedding_shape:
-        raise ValueError(
-            'Shared embedding collection {} contains variable {} of '
-            'unexpected shape {}. Expected shape is {}. '
-            'Suggested fix A: Choose a unique name for this collection. '
-            'Suggested fix B: Do not add any variables to this collection. '
-            'The feature_column library already adds a variable under the '
-            'hood.'.format(
-                self.shared_embedding_collection_name, embedding_weights.name,
-                embedding_weights.get_shape(), embedding_shape))
-    else:
-      embedding_weights = variable_scope.get_variable(
-          name='embedding_weights',
-          shape=embedding_shape,
-          dtype=dtypes.float32,
-          initializer=self.initializer,
-          trainable=self.trainable and trainable,
-          collections=weight_collections)
-      ops.add_to_collection(
-          self.shared_embedding_collection_name, embedding_weights)
+    embedding_weights = self._layer(
+        None, scope=variable_scope.get_variable_scope())
+    # If we're in graph mode and this is called with a different graph,
+    # then we should reset.
+    if not context.executing_eagerly() and (
+        ops.get_default_graph() !=
+        _get_graph_for_variable(embedding_weights)):
+      self._reset_config()
+      embedding_weights = self._layer(
+          None, scope=variable_scope.get_variable_scope())
+
     if self.ckpt_to_load_from is not None:
       to_restore = embedding_weights
       if isinstance(to_restore, variables.PartitionedVariable):
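The net effect of the graph check above: the first call builds the shared embedding variable through `self._layer`; if a later call happens under a different default graph (as when an Estimator builds separate train and eval graphs), the column resets its layer so a fresh variable is created in the new graph instead of reusing one from a dead graph. A usage sketch under those assumptions (TF 1.x; `build_net` and the feature names are illustrative):

import tensorflow as tf

w = tf.feature_column.categorical_column_with_identity('w', num_buckets=10)
i = tf.feature_column.categorical_column_with_identity('i', num_buckets=10)
shared = tf.feature_column.shared_embedding_columns([w, i], dimension=4)

def build_net():
  features = {
      'w': tf.constant([[0, 1]], dtype=tf.int64),
      'i': tf.constant([[2, 3]], dtype=tf.int64),
  }
  return tf.feature_column.input_layer(features, shared)

g1, g2 = tf.Graph(), tf.Graph()
with g1.as_default():
  net1 = build_net()  # builds the shared embedding variable in g1
with g2.as_default():
  net2 = build_net()  # graph mismatch detected; layer reset, new variable in g2
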