path: root/tensorflow/python/feature_column
author    Rohan Jain <rohanj@google.com>    2018-04-18 12:51:56 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-04-18 12:55:16 -0700
commit    8b1c3049028d1c25d7f4acc3af794918d64aafdf (patch)
tree      2adbb863c1d5f5cb446e49e6cad943bdc634c017 /tensorflow/python/feature_column
parent    011740b18b8309bb3126f95b736931d850a83861 (diff)
Moving all state (variables) required for _EmbeddingColumn and _SharedEmbeddingColumn into a base.Layer
PiperOrigin-RevId: 193401873
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--  tensorflow/python/feature_column/feature_column.py       | 337
-rw-r--r--  tensorflow/python/feature_column/feature_column_test.py  | 280
2 files changed, 293 insertions, 324 deletions
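
The change follows one pattern throughout: variable creation moves out of ad-hoc variable_scope calls in the columns and into the build() method of a base.Layer subclass, so the layer owns all state. A minimal sketch of that pattern, assuming a TF 1.x graph-mode environment (class and variable names here are illustrative, not the commit's actual code):

import tensorflow as tf

class _StateLayer(tf.layers.Layer):
  """Holds a single variable; calling the layer returns it."""

  def __init__(self, shape, initializer, **kwargs):
    super(_StateLayer, self).__init__(**kwargs)
    self._shape = shape
    self._initializer = initializer

  def build(self, _):
    # add_variable routes creation through the layer, so variable reuse
    # and GLOBAL_VARIABLES bookkeeping happen in one place.
    self._var = self.add_variable(
        name='weights',
        shape=self._shape,
        dtype=tf.float32,
        initializer=self._initializer,
        trainable=self.trainable)
    self.built = True

  def call(self, _):
    return self._var
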
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index f9201a4794..0ad8131599 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -135,6 +135,7 @@ import numpy as np
import six
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -462,6 +463,16 @@ def linear_model(features,
return predictions
+def _add_to_collections(var, weight_collections):
+ # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
+ # so that we don't have to do this check.
+ if isinstance(var, variables.PartitionedVariable):
+ for constituent_var in list(var):
+ ops.add_to_collections(weight_collections, constituent_var)
+ else:
+ ops.add_to_collections(weight_collections, var)
+
+
class _FCLinearWrapper(base.Layer):
"""Wraps a _FeatureColumn in a layer for use in a linear model.
@@ -482,12 +493,8 @@ class _FCLinearWrapper(base.Layer):
self._units = units
self._sparse_combiner = sparse_combiner
self._weight_collections = weight_collections
- self._state = {}
def build(self, _):
- self._state = self._feature_column._create_state( # pylint: disable=protected-access
- self._weight_collections, self.add_variable)
-
if isinstance(self._feature_column, _CategoricalColumn):
weight = self.add_variable(
name='weights',
@@ -501,7 +508,7 @@ class _FCLinearWrapper(base.Layer):
shape=[num_elements, self._units],
initializer=init_ops.zeros_initializer(),
trainable=self.trainable)
- ops.add_to_collections(self._weight_collections, weight)
+ _add_to_collections(weight, self._weight_collections)
self._weight_var = weight
self.built = True
@@ -513,8 +520,7 @@ class _FCLinearWrapper(base.Layer):
sparse_combiner=self._sparse_combiner,
weight_collections=self._weight_collections,
trainable=self.trainable,
- weight_var=self._weight_var,
- state=self._state)
+ weight_var=self._weight_var)
return weighted_sum
@@ -538,7 +544,7 @@ class _BiasLayer(base.Layer):
shape=[self._units],
initializer=init_ops.zeros_initializer(),
trainable=self.trainable)
- ops.add_to_collections(self._weight_collections, self._bias_variable)
+ _add_to_collections(self._bias_variable, self._weight_collections)
self.built = True
def call(self, _):
@@ -806,11 +812,22 @@ def embedding_column(
initializer = init_ops.truncated_normal_initializer(
mean=0.0, stddev=1 / math.sqrt(dimension))
+ embedding_shape = categorical_column._num_buckets, dimension # pylint: disable=protected-access
+
+ def _creator(weight_collections, scope):
+ embedding_column_layer = _EmbeddingColumnLayer(
+ embedding_shape=embedding_shape,
+ initializer=initializer,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ name='embedding_column_layer')
+ return embedding_column_layer(None, scope=scope) # pylint: disable=not-callable
+
return _EmbeddingColumn(
categorical_column=categorical_column,
dimension=dimension,
combiner=combiner,
- initializer=initializer,
+ layer_creator=_creator,
ckpt_to_load_from=ckpt_to_load_from,
tensor_name_in_ckpt=tensor_name_in_ckpt,
max_norm=max_norm,
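
From the caller's side nothing changes: embedding_column keeps its public signature, and the initializer is captured by the _creator closure instead of being stored on the column's namedtuple. A usage sketch against the public TF 1.x API (feature name illustrative):

import tensorflow as tf

terms = tf.feature_column.categorical_column_with_hash_bucket(
    'terms', hash_bucket_size=100)
# The layer, and thus its embedding_weights variable, is only built when
# the column is first used, e.g. inside tf.feature_column.input_layer.
terms_embedded = tf.feature_column.embedding_column(terms, dimension=8)
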
@@ -933,6 +950,7 @@ def shared_embedding_columns(
sorted_columns = sorted(categorical_columns, key=lambda x: x.name)
c0 = sorted_columns[0]
+ num_buckets = c0._num_buckets # pylint: disable=protected-access
if not isinstance(c0, _CategoricalColumn):
raise ValueError(
'All categorical_columns must be subclasses of _CategoricalColumn. '
@@ -948,23 +966,45 @@ def shared_embedding_columns(
'the same type, or be weighted_categorical_column of the same type. '
'Given column: {} of type: {} does not match given column: {} of '
'type: {}'.format(c0, type(c0), c, type(c)))
+ if num_buckets != c._num_buckets: # pylint: disable=protected-access
+ raise ValueError(
+ 'To use shared_embedding_column, all categorical_columns must have '
+ 'the same number of buckets. Given column: {} with buckets: {} does '
+ 'not match column: {} with buckets: {}'.format(
+ c0, num_buckets, c, c._num_buckets)) # pylint: disable=protected-access
if not shared_embedding_collection_name:
shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
shared_embedding_collection_name += '_shared_embedding'
+ # Create the shared state (an _EmbeddingColumnLayer) here.
+ embedding_shape = num_buckets, dimension
+
+ shared_embedding_column_layer = _EmbeddingColumnLayer(
+ embedding_shape=embedding_shape,
+ initializer=initializer,
+ weight_collections=[],
+ trainable=trainable,
+ name=shared_embedding_collection_name)
+
result = []
for column in categorical_columns:
- result.append(_SharedEmbeddingColumn(
- categorical_column=column,
- dimension=dimension,
- combiner=combiner,
- initializer=initializer,
- shared_embedding_collection_name=shared_embedding_collection_name,
- ckpt_to_load_from=ckpt_to_load_from,
- tensor_name_in_ckpt=tensor_name_in_ckpt,
- max_norm=max_norm,
- trainable=trainable))
+ result.append(
+ _SharedEmbeddingColumn(
+ categorical_column=column,
+ initializer=initializer,
+ dimension=dimension,
+ combiner=combiner,
+ var_scope_name=shared_embedding_collection_name,
+ ckpt_to_load_from=ckpt_to_load_from,
+ tensor_name_in_ckpt=tensor_name_in_ckpt,
+ max_norm=max_norm,
+ trainable=trainable))
+
+ for single_result in result:
+ single_result._set_layer(shared_embedding_column_layer) # pylint: disable=protected-access
+ single_result._set_all_columns(result) # pylint: disable=protected-access
+
return result
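
Because every column returned by shared_embedding_columns now holds a reference to the single _EmbeddingColumnLayer created above, the sharing is explicit in the objects rather than mediated by a graph collection. A usage sketch (feature names illustrative):

import tensorflow as tf

watched = tf.feature_column.categorical_column_with_identity(
    'watched_video_id', num_buckets=1000)
impression = tf.feature_column.categorical_column_with_identity(
    'impression_video_id', num_buckets=1000)
# Both columns look up ids in the same [1000, 16] embedding table.
watched_emb, impression_emb = tf.feature_column.shared_embedding_columns(
    [watched, impression], dimension=16)
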
@@ -1721,6 +1761,57 @@ def crossed_column(keys, hash_bucket_size, hash_key=None):
hash_key=hash_key)
+# TODO(rohanj): Clearly define semantics of this layer.
+class _EmbeddingColumnLayer(base.Layer):
+ """A layer that stores all the state required for a embedding column."""
+
+ def __init__(self,
+ embedding_shape,
+ initializer,
+ weight_collections=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ """Constructor.
+
+ Args:
+ embedding_shape: Shape of the embedding variable used for lookup.
+ initializer: A variable initializer function to be used in embedding
+ variable initialization. If not specified, defaults to
+ `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
+ `1/sqrt(dimension)`.
+ weight_collections: A list of collection names to which the Variable will
+ be added. Note that variables will also be added to the collections
+ `tf.GraphKeys.GLOBAL_VARIABLES` and `tf.GraphKeys.MODEL_VARIABLES`.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: Name of the layer.
+ **kwargs: keyword named properties.
+ """
+ super(_EmbeddingColumnLayer, self).__init__(
+ trainable=trainable, name=name, **kwargs)
+ self._embedding_shape = embedding_shape
+ self._initializer = initializer
+ self._weight_collections = weight_collections
+
+ def build(self, _):
+ self._embedding_weight_var = self.add_variable(
+ name='embedding_weights',
+ shape=self._embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self._initializer,
+ trainable=self.trainable)
+ # self.add_variable already appends to GLOBAL_VARIABLES collection.
+ if self._weight_collections and not context.executing_eagerly():
+ for weight_collection in self._weight_collections:
+ if weight_collection != ops.GraphKeys.GLOBAL_VARIABLES:
+ _add_to_collections(self._embedding_weight_var, [weight_collection])
+ self.built = True
+
+ def call(self, _):
+ return self._embedding_weight_var
+
+
class _FeatureColumn(object):
"""Represents a feature column abstraction.
@@ -1794,18 +1885,13 @@ class _FeatureColumn(object):
"""
pass
- def _create_state(self, weight_collections=None, creator=None):
- """Returns an object that captures the state of the column.
+ def _reset_config(self):
+ """Resets the configuration in the column.
- Args:
- weight_collections: Collections to add the variable to
- creator: Variable creator method called, if provided.
-
- Returns:
- An object that encapsulates the state of the column. Can return None.
+ Some feature columns, e.g. embedding or shared embedding columns, might
+ have state that needs to be reset, for example when the column is used
+ in a new graph. Use this method in that scenario.
"""
- del weight_collections, creator # Unused
- return None
class _DenseColumn(_FeatureColumn):
@@ -1826,11 +1912,7 @@ class _DenseColumn(_FeatureColumn):
pass
@abc.abstractmethod
- def _get_dense_tensor(self,
- inputs,
- weight_collections=None,
- trainable=None,
- state=None):
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
"""Returns a `Tensor`.
The output of this function will be used by model-builder-functions. For
@@ -1848,9 +1930,6 @@ class _DenseColumn(_FeatureColumn):
will be created) are added.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see @{tf.Variable}).
- state: An object encapsulating the state of the column. Columns that
- create state using the _create_state method would have that state
- passed in to this method.
Returns:
`Tensor` of shape [batch_size] + `_variable_shape`.
@@ -1864,8 +1943,7 @@ def _create_weighted_sum(column,
sparse_combiner,
weight_collections,
trainable,
- weight_var=None,
- state=None):
+ weight_var=None):
"""Creates a weighted sum for a dense or sparse column for linear_model."""
if isinstance(column, _CategoricalColumn):
return _create_categorical_column_weighted_sum(
@@ -1883,8 +1961,7 @@ def _create_weighted_sum(column,
units=units,
weight_collections=weight_collections,
trainable=trainable,
- weight_var=weight_var,
- state=state)
+ weight_var=weight_var)
def _create_dense_column_weighted_sum(column,
@@ -1892,20 +1969,12 @@ def _create_dense_column_weighted_sum(column,
units,
weight_collections,
trainable,
- weight_var=None,
- state=None):
+ weight_var=None):
"""Create a weighted sum of a dense column for linear_model."""
- if state is not None:
- tensor = column._get_dense_tensor( # pylint: disable=protected-access
- builder,
- weight_collections=weight_collections,
- trainable=trainable,
- state=state)
- else:
- tensor = column._get_dense_tensor( # pylint: disable=protected-access
- builder,
- weight_collections=weight_collections,
- trainable=trainable)
+ tensor = column._get_dense_tensor( # pylint: disable=protected-access
+ builder,
+ weight_collections=weight_collections,
+ trainable=trainable)
num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access
batch_size = array_ops.shape(tensor)[0]
tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
@@ -2368,10 +2437,10 @@ class _BucketizedColumn(_DenseColumn, _CategoricalColumn,
class _EmbeddingColumn(
_DenseColumn, _SequenceDenseColumn,
- collections.namedtuple('_EmbeddingColumn', (
- 'categorical_column', 'dimension', 'combiner', 'initializer',
- 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable'
- ))):
+ collections.namedtuple(
+ '_EmbeddingColumn',
+ ('categorical_column', 'dimension', 'combiner', 'layer_creator',
+ 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable'))):
"""See `embedding_column`."""
@property
@@ -2393,33 +2462,10 @@ class _EmbeddingColumn(
self._shape = tensor_shape.vector(self.dimension)
return self._shape
- def _create_state(self, weight_collections=None, creator=None):
- variables_map = {}
- embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
- if creator is not None:
- embedding_weights = creator(
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer,
- trainable=self.trainable)
- ops.add_to_collections(weight_collections, embedding_weights)
- else:
- embedding_weights = variable_scope.get_variable(
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer,
- trainable=self.trainable,
- collections=weight_collections)
- variables_map['embedding_weights'] = embedding_weights
- return variables_map
-
def _get_dense_tensor_internal(self,
inputs,
weight_collections=None,
- trainable=None,
- state=None):
+ trainable=None):
"""Private method that follows the signature of _get_dense_tensor."""
# Get sparse IDs and weights.
sparse_tensors = self.categorical_column._get_sparse_tensors( # pylint: disable=protected-access
@@ -2427,9 +2473,9 @@ class _EmbeddingColumn(
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- if state is None:
- state = self._create_state(weight_collections)
- embedding_weights = state['embedding_weights']
+ embedding_weights = self.layer_creator(
+ weight_collections=weight_collections,
+ scope=variable_scope.get_variable_scope())
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
@@ -2448,11 +2494,7 @@ class _EmbeddingColumn(
name='%s_weights' % self.name,
max_norm=self.max_norm)
- def _get_dense_tensor(self,
- inputs,
- weight_collections=None,
- trainable=None,
- state=None):
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
if isinstance(self.categorical_column, _SequenceCategoricalColumn):
raise ValueError(
'In embedding_column: {}. '
@@ -2467,8 +2509,7 @@ class _EmbeddingColumn(
return self._get_dense_tensor_internal(
inputs=inputs,
weight_collections=weight_collections,
- trainable=trainable,
- state=state)
+ trainable=trainable)
def _get_sequence_dense_tensor(
self, inputs, weight_collections=None, trainable=None):
@@ -2492,13 +2533,20 @@ class _EmbeddingColumn(
dense_tensor=dense_tensor, sequence_length=sequence_length)
+def _get_graph_for_variable(var):
+ if isinstance(var, variables.PartitionedVariable):
+ return list(var)[0].graph
+ else:
+ return var.graph
+
+
class _SharedEmbeddingColumn(
_DenseColumn,
- collections.namedtuple('_SharedEmbeddingColumn', (
- 'categorical_column', 'dimension', 'combiner', 'initializer',
- 'shared_embedding_collection_name', 'ckpt_to_load_from',
- 'tensor_name_in_ckpt', 'max_norm', 'trainable'
- ))):
+ collections.namedtuple(
+ '_SharedEmbeddingColumn',
+ ('categorical_column', 'dimension', 'combiner', 'initializer',
+ 'var_scope_name', 'ckpt_to_load_from', 'tensor_name_in_ckpt',
+ 'max_norm', 'trainable'))):
"""See `embedding_column`."""
@property
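
_get_graph_for_variable exists because a PartitionedVariable (at least as of this commit) does not expose a .graph property of its own, so the graph has to be read off one of its shards. A sketch, assuming TF 1.x:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  with tf.variable_scope(
      'scope', partitioner=tf.fixed_size_partitioner(num_shards=2)):
    pv = tf.get_variable('w', shape=[4, 2])
# The wrapper has no graph of its own; its first shard does.
assert list(pv)[0].graph is g
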
@@ -2509,7 +2557,7 @@ class _SharedEmbeddingColumn(
@property
def _var_scope_name(self):
- return self.shared_embedding_collection_name
+ return self.var_scope_name
@property
def _parse_example_spec(self):
@@ -2518,45 +2566,29 @@ class _SharedEmbeddingColumn(
def _transform_feature(self, inputs):
return inputs.get(self.categorical_column)
+ def _set_layer(self, layer):
+ self._layer = layer
+
+ def _set_all_columns(self, all_columns):
+ self._all_columns = all_columns
+
+ def _reset_config(self):
+ config = self._layer.get_config()
+ config['embedding_shape'] = (
+ self.categorical_column._num_buckets, # pylint: disable=protected-access
+ self.dimension)
+ config['initializer'] = self.initializer
+ self._layer = self._layer.__class__.from_config(config)
+ for column in self._all_columns:
+ column._set_layer(self._layer) # pylint: disable=protected-access
+
@property
def _variable_shape(self):
if not hasattr(self, '_shape'):
self._shape = tensor_shape.vector(self.dimension)
return self._shape
- def _create_state(self, weight_collections=None, creator=None):
- variables_map = {}
- shared_embedding_collection = ops.get_collection(
- self.shared_embedding_collection_name)
- if not shared_embedding_collection:
- embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
- if creator is not None:
- embedding_weights = creator(
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer,
- trainable=self.trainable)
- ops.add_to_collections(weight_collections, embedding_weights)
- else:
- embedding_weights = variable_scope.get_variable(
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer,
- trainable=self.trainable,
- collections=weight_collections)
- ops.add_to_collection(self.shared_embedding_collection_name,
- embedding_weights)
- variables_map['embedding_weights'] = embedding_weights
-
- return variables_map
-
- def _get_dense_tensor(self,
- inputs,
- weight_collections=None,
- trainable=None,
- state=None):
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
# This method is called from a variable_scope with name _var_scope_name,
# which is shared among all shared embeddings. Open a name_scope here, so
# that the ops for different columns have distinct names.
@@ -2567,38 +2599,17 @@ class _SharedEmbeddingColumn(
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
- shared_embedding_collection = ops.get_collection(
- self.shared_embedding_collection_name)
- if shared_embedding_collection:
- if len(shared_embedding_collection) > 1:
- raise ValueError(
- 'Collection {} can only contain one variable. '
- 'Suggested fix A: Choose a unique name for this collection. '
- 'Suggested fix B: Do not add any variables to this collection. '
- 'The feature_column library already adds a variable under the '
- 'hood.'.format(shared_embedding_collection))
- embedding_weights = shared_embedding_collection[0]
- if embedding_weights.get_shape() != embedding_shape:
- raise ValueError(
- 'Shared embedding collection {} contains variable {} of '
- 'unexpected shape {}. Expected shape is {}. '
- 'Suggested fix A: Choose a unique name for this collection. '
- 'Suggested fix B: Do not add any variables to this collection. '
- 'The feature_column library already adds a variable under the '
- 'hood.'.format(
- self.shared_embedding_collection_name, embedding_weights.name,
- embedding_weights.get_shape(), embedding_shape))
- else:
- embedding_weights = variable_scope.get_variable(
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer,
- trainable=self.trainable and trainable,
- collections=weight_collections)
- ops.add_to_collection(
- self.shared_embedding_collection_name, embedding_weights)
+ embedding_weights = self._layer(
+ None, scope=variable_scope.get_variable_scope())
+ # If we're in graph mode and this is called with a different graph,
+ # then we should reset.
+ if not context.executing_eagerly() and (
+ ops.get_default_graph() !=
+ _get_graph_for_variable(embedding_weights)):
+ self._reset_config()
+ embedding_weights = self._layer(
+ None, scope=variable_scope.get_variable_scope())
+
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
if isinstance(to_restore, variables.PartitionedVariable):
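
The graph check above is what makes the same shared columns reusable across graphs: on a mismatch, _reset_config rebuilds the layer from its config so a fresh variable is created in the new graph. A compressed sketch of that scenario, mirroring the new test below (feature names illustrative):

import tensorflow as tf

a = tf.feature_column.categorical_column_with_identity('a', num_buckets=3)
b = tf.feature_column.categorical_column_with_identity('b', num_buckets=3)
cols = tf.feature_column.shared_embedding_columns([a, b], dimension=2)

for _ in range(2):
  with tf.Graph().as_default():
    features = {
        'a': tf.SparseTensor([[0, 0]], [1], [1, 1]),
        'b': tf.SparseTensor([[0, 0]], [2], [1, 1]),
    }
    tf.feature_column.input_layer(features, cols)
    # One embedding variable per graph, shared by both columns.
    assert len(tf.global_variables()) == 1
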
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 62718db0e5..46404abadc 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -2885,6 +2885,114 @@ class FunctionalInputLayerTest(test.TestCase):
features['price2']: [[1.], [5.]],
})
+ def test_multiple_layers_with_same_embedding_column(self):
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc.embedding_column(
+ some_sparse_column, dimension=10)
+
+ with ops.Graph().as_default():
+ features = {
+ 'sparse_feature': [['a'], ['x']],
+ }
+ all_cols = [some_embedding_column]
+ fc.input_layer(features, all_cols)
+ fc.input_layer(features, all_cols)
+ # Make sure that 2 variables get created in this case.
+ self.assertEqual(2, len(
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+ expected_var_names = [
+ 'input_layer/sparse_feature_embedding/embedding_weights:0',
+ 'input_layer_1/sparse_feature_embedding/embedding_weights:0'
+ ]
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+ def test_multiple_layers_with_same_shared_embedding_column(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ embedding_dimension = 2
+ embedding_column_b, embedding_column_a = fc.shared_embedding_columns(
+ [categorical_column_b, categorical_column_a],
+ dimension=embedding_dimension)
+
+ with ops.Graph().as_default():
+ features = {
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ all_cols = [embedding_column_a, embedding_column_b]
+ fc.input_layer(features, all_cols)
+ fc.input_layer(features, all_cols)
+ # Make sure that only 1 variable gets created in this case.
+ self.assertEqual(1, len(
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+ self.assertItemsEqual(
+ ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+ def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ embedding_dimension = 2
+ embedding_column_b, embedding_column_a = fc.shared_embedding_columns(
+ [categorical_column_b, categorical_column_a],
+ dimension=embedding_dimension)
+ all_cols = [embedding_column_a, embedding_column_b]
+
+ with ops.Graph().as_default():
+ features = {
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ fc.input_layer(features, all_cols)
+ # Make sure that only 1 variable gets created in this case.
+ self.assertEqual(1, len(
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+
+ with ops.Graph().as_default():
+ features1 = {
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+
+ fc.input_layer(features1, all_cols)
+ # Make sure that only 1 variable gets created in this case.
+ self.assertEqual(1, len(
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+ self.assertItemsEqual(
+ ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
def test_with_numpy_input_fn(self):
embedding_values = (
(1., 2., 3., 4., 5.), # id 0
@@ -4504,7 +4612,6 @@ class EmbeddingColumnTest(test.TestCase):
self.assertIs(categorical_column, embedding_column.categorical_column)
self.assertEqual(embedding_dimension, embedding_column.dimension)
self.assertEqual('mean', embedding_column.combiner)
- self.assertIsNotNone(embedding_column.initializer)
self.assertIsNone(embedding_column.ckpt_to_load_from)
self.assertIsNone(embedding_column.tensor_name_in_ckpt)
self.assertIsNone(embedding_column.max_norm)
@@ -4529,7 +4636,6 @@ class EmbeddingColumnTest(test.TestCase):
self.assertIs(categorical_column, embedding_column.categorical_column)
self.assertEqual(embedding_dimension, embedding_column.dimension)
self.assertEqual('my_combiner', embedding_column.combiner)
- self.assertEqual('my_initializer', embedding_column.initializer())
self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from)
self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt)
self.assertEqual(42., embedding_column.max_norm)
@@ -4560,7 +4666,6 @@ class EmbeddingColumnTest(test.TestCase):
self.assertEqual(embedding_dimension, embedding_column.dimension)
self.assertEqual('my_combiner', embedding_column.combiner)
- self.assertEqual('my_initializer', embedding_column.initializer())
self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from)
self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt)
self.assertEqual(42., embedding_column.max_norm)
@@ -4675,72 +4780,6 @@ class EmbeddingColumnTest(test.TestCase):
self.assertAllEqual(embedding_values, global_vars[0].eval())
self.assertAllEqual(expected_lookups, embedding_lookup.eval())
- def test_get_dense_tensor_with_state(self):
- # Inputs.
- vocabulary_size = 3
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- # example 2, ids []
- # example 3, ids [1]
- indices=((0, 0), (1, 0), (1, 4), (3, 0)),
- values=(2, 0, 1, 1),
- dense_shape=(4, 5))
-
- # Embedding variable.
- embedding_dimension = 2
- embedding_values = (
- (1., 2.), # id 0
- (3., 5.), # id 1
- (7., 11.) # id 2
- )
-
- def _initializer(shape, dtype, partition_info):
- self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
- self.assertEqual(dtypes.float32, dtype)
- self.assertIsNone(partition_info)
- return embedding_values
-
- # Expected lookup result, using combiner='mean'.
- expected_lookups = (
- # example 0, ids [2], embedding = [7, 11]
- (7., 11.),
- # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
- (2., 3.5),
- # example 2, ids [], embedding = [0, 0]
- (0., 0.),
- # example 3, ids [1], embedding = [3, 5]
- (3., 5.),
- )
-
- # Build columns.
- categorical_column = fc.categorical_column_with_identity(
- key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc.embedding_column(
- categorical_column,
- dimension=embedding_dimension,
- initializer=_initializer)
-
- # Create embedding_weights variable.
- weight_collections = [
- ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES
- ]
- state = embedding_column._create_state(weight_collections)
-
- # Provide sparse input and get dense result.
- embedding_lookup = embedding_column._get_dense_tensor(
- _LazyBuilder({
- 'aaa': sparse_input
- }), state=state)
-
- # Assert expected embedding variable and lookups.
- global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ('embedding_weights:0',), tuple([v.name for v in global_vars]))
- with _initialized_session():
- self.assertAllEqual(embedding_values, global_vars[0].eval())
- self.assertAllEqual(expected_lookups, embedding_lookup.eval())
-
def test_get_dense_tensor_3d(self):
# Inputs.
vocabulary_size = 4
@@ -4795,8 +4834,8 @@ class EmbeddingColumnTest(test.TestCase):
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
with _initialized_session():
self.assertAllEqual(embedding_values, global_vars[0].eval())
self.assertAllEqual(expected_lookups, embedding_lookup.eval())
@@ -4823,8 +4862,9 @@ class EmbeddingColumnTest(test.TestCase):
}), weight_collections=('my_vars',))
# Assert expected embedding variable and lookups.
- self.assertItemsEqual(
- [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
my_vars = ops.get_collection('my_vars')
self.assertItemsEqual(
('embedding_weights:0',), tuple([v.name for v in my_vars]))
@@ -5243,14 +5283,12 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertEqual(embedding_dimension, embedding_column_b.dimension)
self.assertEqual('mean', embedding_column_a.combiner)
self.assertEqual('mean', embedding_column_b.combiner)
- self.assertIsNotNone(embedding_column_a.initializer)
- self.assertIsNotNone(embedding_column_b.initializer)
self.assertIsNone(embedding_column_a.ckpt_to_load_from)
self.assertIsNone(embedding_column_b.ckpt_to_load_from)
self.assertEqual('aaa_bbb_shared_embedding',
- embedding_column_a.shared_embedding_collection_name)
+ embedding_column_a.var_scope_name)
self.assertEqual('aaa_bbb_shared_embedding',
- embedding_column_b.shared_embedding_collection_name)
+ embedding_column_b.var_scope_name)
self.assertIsNone(embedding_column_a.tensor_name_in_ckpt)
self.assertIsNone(embedding_column_b.tensor_name_in_ckpt)
self.assertIsNone(embedding_column_a.max_norm)
@@ -5296,12 +5334,10 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertEqual(embedding_dimension, embedding_column_b.dimension)
self.assertEqual('my_combiner', embedding_column_a.combiner)
self.assertEqual('my_combiner', embedding_column_b.combiner)
- self.assertEqual('my_initializer', embedding_column_a.initializer())
- self.assertEqual('my_initializer', embedding_column_b.initializer())
self.assertEqual('shared_embedding_collection_name',
- embedding_column_a.shared_embedding_collection_name)
+ embedding_column_a.var_scope_name)
self.assertEqual('shared_embedding_collection_name',
- embedding_column_b.shared_embedding_collection_name)
+ embedding_column_b.var_scope_name)
self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from)
self.assertEqual('my_ckpt', embedding_column_b.ckpt_to_load_from)
self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt)
@@ -5351,9 +5387,8 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertEqual(embedding_dimension, embedding_column_a.dimension)
self.assertEqual('my_combiner', embedding_column_a.combiner)
- self.assertEqual('my_initializer', embedding_column_a.initializer())
self.assertEqual('shared_embedding_collection_name',
- embedding_column_a.shared_embedding_collection_name)
+ embedding_column_a.var_scope_name)
self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from)
self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt)
self.assertEqual(42., embedding_column_a.max_norm)
@@ -5537,80 +5572,6 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval())
self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval())
- def test_get_dense_tensor_with_state(self):
- # Inputs.
- vocabulary_size = 3
- # -1 values are ignored.
- input_a = np.array([
- [2, -1, -1], # example 0, ids [2]
- [0, 1, -1]
- ]) # example 1, ids [0, 1]
- input_b = np.array([
- [0, -1, -1], # example 0, ids [0]
- [-1, -1, -1]
- ]) # example 1, ids []
- input_features = {'aaa': input_a, 'bbb': input_b}
-
- # Embedding variable.
- embedding_dimension = 2
- embedding_values = (
- (1., 2.), # id 0
- (3., 5.), # id 1
- (7., 11.) # id 2
- )
-
- def _initializer(shape, dtype, partition_info):
- self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
- self.assertEqual(dtypes.float32, dtype)
- self.assertIsNone(partition_info)
- return embedding_values
-
- # Expected lookup result, using combiner='mean'.
- expected_lookups_a = (
- # example 0:
- (7., 11.), # ids [2], embedding = [7, 11]
- # example 1:
- (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
- )
- expected_lookups_b = (
- # example 0:
- (1., 2.), # ids [0], embedding = [1, 2]
- # example 1:
- (0., 0.), # ids [], embedding = [0, 0]
- )
-
- # Build columns.
- categorical_column_a = fc.categorical_column_with_identity(
- key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc.categorical_column_with_identity(
- key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
- [categorical_column_a, categorical_column_b],
- dimension=embedding_dimension,
- initializer=_initializer)
-
- # Create state.
- weight_collections = [
- ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES
- ]
- state = embedding_column_a._create_state(weight_collections)
-
- # Provide sparse input and get dense result.
- embedding_lookup_a = embedding_column_a._get_dense_tensor(
- _LazyBuilder(input_features), state=state)
- embedding_lookup_b = embedding_column_b._get_dense_tensor(
- _LazyBuilder(input_features), state=state)
-
- # Assert expected embedding variable and lookups.
- global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ('embedding_weights:0',), tuple([v.name for v in global_vars]))
- embedding_var = global_vars[0]
- with _initialized_session():
- self.assertAllEqual(embedding_values, embedding_var.eval())
- self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval())
- self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval())
-
def test_get_dense_tensor_placeholder_inputs(self):
# Inputs.
vocabulary_size = 3
@@ -5912,10 +5873,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
tuple([v.name for v in trainable_vars]))
else:
self.assertItemsEqual([], tuple([v.name for v in trainable_vars]))
- shared_embedding_vars = ops.get_collection('aaa_bbb_shared_embedding')
- self.assertItemsEqual(
- ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
- tuple([v.name for v in shared_embedding_vars]))
+ shared_embedding_vars = global_vars
with _initialized_session():
self.assertAllEqual(embedding_values, shared_embedding_vars[0].eval())
self.assertAllEqual(expected_lookups, input_layer.eval())