author     Rohan Jain <rohanj@google.com>                    2018-06-22 17:11:45 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-06-22 17:14:29 -0700
commit     3676ca9f6381a3a7ef55122305ecc8aa0048e55c (patch)
tree       2ce204f3e36ee581c7971ef1c3c82dba260015f6 /tensorflow/python/feature_column
parent     aa63c6b0209c1194a392e4327f6211c7c13f40dc (diff)
Automated g4 rollback of changelist 201585400
PiperOrigin-RevId: 201764268
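
In effect, this rollback restores the graph-collection-based implementation of
shared embedding columns: _SharedEmbeddingColumn carries a
shared_embedding_collection_name again instead of a var_scope_name plus a
shared _EmbeddingColumnLayer, embedding weights are created and reused via
ops.get_collection / ops.add_to_collection, and shared_embedding_columns now
raises a RuntimeError under eager execution. A hedged usage sketch of the
public API this change touches (not part of the commit; assumes TF 1.x graph
mode, and the feature keys are illustrative):

    import tensorflow as tf

    # Two categorical columns that should share one embedding table.
    col_a = tf.feature_column.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    col_b = tf.feature_column.categorical_column_with_identity(
        key='bbb', num_buckets=3)

    # After this rollback, both returned columns record the collection name
    # 'aaa_bbb_shared_embedding' and look their weights up in that collection.
    shared_columns = tf.feature_column.shared_embedding_columns(
        [col_a, col_b], dimension=2)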
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--  tensorflow/python/feature_column/feature_column.py       123
-rw-r--r--  tensorflow/python/feature_column/feature_column_test.py   10
2 files changed, 68 insertions(+), 65 deletions(-)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 7d070dc27c..40219e4b34 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -468,13 +468,25 @@ def linear_model(features,
def _add_to_collections(var, weight_collections):
- # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
- # so that we don't have to do this check.
- if isinstance(var, variables.PartitionedVariable):
- for constituent_var in list(var):
- ops.add_to_collections(weight_collections, constituent_var)
- else:
- ops.add_to_collections(weight_collections, var)
+ """Adds a var to the list of weight_collections provided.
+
+ Handles the case for partitioned and non-partitioned variables.
+
+ Args:
+ var: A variable or Partitioned Variable.
+ weight_collections: List of collections to add variable to.
+ """
+ for weight_collection in weight_collections:
+ # The layer self.add_variable call already adds it to GLOBAL_VARIABLES.
+ if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES:
+ continue
+ # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
+ # so that we don't have to do this check.
+ if isinstance(var, variables.PartitionedVariable):
+ for constituent_var in list(var):
+ ops.add_to_collection(weight_collection, constituent_var)
+ else:
+ ops.add_to_collection(weight_collection, var)
class _FCLinearWrapper(base.Layer):
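
The rewritten helper above skips GraphKeys.GLOBAL_VARIABLES because the
layer's add_variable call has already placed the variable in that collection.
A minimal sketch of that behavior (not part of the diff; assumes TF 1.x graph
mode, where get_variable likewise auto-registers in GLOBAL_VARIABLES):

    import tensorflow as tf

    with tf.Graph().as_default():
        # get_variable already appends to GLOBAL_VARIABLES, mirroring what
        # base.Layer.add_variable does inside feature_column.
        var = tf.get_variable('weights', shape=[2, 3])
        for collection in [tf.GraphKeys.GLOBAL_VARIABLES, 'my_weights']:
            if collection == tf.GraphKeys.GLOBAL_VARIABLES:
                continue  # avoid a duplicate entry in GLOBAL_VARIABLES
            tf.add_to_collection(collection, var)
        assert tf.get_collection('my_weights') == [var]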
@@ -585,6 +597,8 @@ class _LinearModel(training.Model):
self._feature_columns = _normalize_feature_columns(
feature_columns)
self._weight_collections = list(weight_collections or [])
+ if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections:
+ self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections:
self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
@@ -973,7 +987,12 @@ def shared_embedding_columns(
ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
is specified.
ValueError: if `initializer` is specified and is not callable.
+ RuntimeError: if eager execution is enabled.
"""
+ if context.executing_eagerly():
+ raise RuntimeError('shared_embedding_columns are not supported when eager '
+ 'execution is enabled.')
+
if (dimension is None) or (dimension < 1):
raise ValueError('Invalid dimension {}.'.format(dimension))
if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
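
Because the restored implementation depends on graph collections, it cannot
work under eager execution, so the new guard fails fast. A hedged sketch of
the resulting behavior (not part of the diff; assumes TF 1.x with
tf.enable_eager_execution available):

    import tensorflow as tf

    tf.enable_eager_execution()
    col_a = tf.feature_column.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    col_b = tf.feature_column.categorical_column_with_identity(
        key='bbb', num_buckets=3)
    try:
        tf.feature_column.shared_embedding_columns(
            [col_a, col_b], dimension=2)
    except RuntimeError as err:
        # 'shared_embedding_columns are not supported when eager execution
        # is enabled.'
        print(err)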
@@ -1018,16 +1037,6 @@ def shared_embedding_columns(
shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
shared_embedding_collection_name += '_shared_embedding'
- # Create the state (_SharedEmbeddingColumnLayer) here.
- embedding_shape = num_buckets, dimension
-
- shared_embedding_column_layer = _EmbeddingColumnLayer(
- embedding_shape=embedding_shape,
- initializer=initializer,
- weight_collections=[],
- trainable=trainable,
- name=shared_embedding_collection_name)
-
result = []
for column in categorical_columns:
result.append(
@@ -1036,16 +1045,12 @@ def shared_embedding_columns(
initializer=initializer,
dimension=dimension,
combiner=combiner,
- var_scope_name=shared_embedding_collection_name,
+ shared_embedding_collection_name=shared_embedding_collection_name,
ckpt_to_load_from=ckpt_to_load_from,
tensor_name_in_ckpt=tensor_name_in_ckpt,
max_norm=max_norm,
trainable=trainable))
- for single_result in result:
- single_result._set_layer(shared_embedding_column_layer) # pylint: disable=protected-access
- single_result._set_all_columns(result) # pylint: disable=protected-access
-
return result
@@ -1865,11 +1870,8 @@ class _EmbeddingColumnLayer(base.Layer):
dtype=dtypes.float32,
initializer=self._initializer,
trainable=self.trainable)
- # self.add_variable already appends to GLOBAL_VARIABLES collection.
if self._weight_collections and not context.executing_eagerly():
- for weight_collection in self._weight_collections:
- if weight_collection != ops.GraphKeys.GLOBAL_VARIABLES:
- _add_to_collections(self._embedding_weight_var, [weight_collection])
+ _add_to_collections(self._embedding_weight_var, self._weight_collections)
self.built = True
def call(self, _):
@@ -2651,8 +2653,8 @@ class _SharedEmbeddingColumn(
collections.namedtuple(
'_SharedEmbeddingColumn',
('categorical_column', 'dimension', 'combiner', 'initializer',
- 'var_scope_name', 'ckpt_to_load_from', 'tensor_name_in_ckpt',
- 'max_norm', 'trainable'))):
+ 'shared_embedding_collection_name', 'ckpt_to_load_from',
+ 'tensor_name_in_ckpt', 'max_norm', 'trainable'))):
"""See `embedding_column`."""
@property
@@ -2663,7 +2665,7 @@ class _SharedEmbeddingColumn(
@property
def _var_scope_name(self):
- return self.var_scope_name
+ return self.shared_embedding_collection_name
@property
def _parse_example_spec(self):
@@ -2672,22 +2674,6 @@ class _SharedEmbeddingColumn(
def _transform_feature(self, inputs):
return inputs.get(self.categorical_column)
- def _set_layer(self, layer):
- self._layer = layer
-
- def _set_all_columns(self, all_columns):
- self._all_columns = all_columns
-
- def _reset_config(self):
- config = self._layer.get_config()
- config['embedding_shape'] = (
- self.categorical_column._num_buckets, # pylint: disable=protected-access
- self.dimension)
- config['initializer'] = self.initializer
- self._layer = self._layer.__class__.from_config(config)
- for column in self._all_columns:
- column._set_layer(self._layer) # pylint: disable=protected-access
-
@property
def _variable_shape(self):
if not hasattr(self, '_shape'):
@@ -2709,19 +2695,38 @@ class _SharedEmbeddingColumn(
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- self._layer.set_weight_collections(weight_collections)
- embedding_weights = self._layer(
- None, scope=variable_scope.get_variable_scope())
- # If we're in graph mode and this is called with a different graph,
- # then we should reset.
- if not context.executing_eagerly() and (
- ops.get_default_graph() !=
- _get_graph_for_variable(embedding_weights)):
- self._reset_config()
- self._layer.set_weight_collections(weight_collections)
- embedding_weights = self._layer(
- None, scope=variable_scope.get_variable_scope())
-
+ embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
+ shared_embedding_collection = ops.get_collection(
+ self.shared_embedding_collection_name)
+ if shared_embedding_collection:
+ if len(shared_embedding_collection) > 1:
+ raise ValueError(
+ 'Collection {} can only contain one variable. '
+ 'Suggested fix A: Choose a unique name for this collection. '
+ 'Suggested fix B: Do not add any variables to this collection. '
+ 'The feature_column library already adds a variable under the '
+ 'hood.'.format(shared_embedding_collection))
+ embedding_weights = shared_embedding_collection[0]
+ if embedding_weights.get_shape() != embedding_shape:
+ raise ValueError(
+ 'Shared embedding collection {} contains variable {} of '
+ 'unexpected shape {}. Expected shape is {}. '
+ 'Suggested fix A: Choose a unique name for this collection. '
+ 'Suggested fix B: Do not add any variables to this collection. '
+ 'The feature_column library already adds a variable under the '
+ 'hood.'.format(self.shared_embedding_collection_name,
+ embedding_weights.name,
+ embedding_weights.get_shape(), embedding_shape))
+ else:
+ embedding_weights = variable_scope.get_variable(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer,
+ trainable=self.trainable and trainable,
+ collections=weight_collections)
+ ops.add_to_collection(self.shared_embedding_collection_name,
+ embedding_weights)
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
if isinstance(to_restore, variables.PartitionedVariable):
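
The hunk above is the heart of the rollback: the first shared column to be
materialized creates embedding_weights and registers it in the named
collection, every later column finds the same variable there, and a collection
that already holds more than one variable or a wrongly shaped one raises a
ValueError. A standalone sketch of that reuse pattern (not part of the diff;
the helper name is hypothetical, and TF 1.x graph mode is assumed):

    import tensorflow as tf

    def get_or_create_shared_weights(collection_name, shape):
        existing = tf.get_collection(collection_name)
        if existing:
            return existing[0]  # reuse the single shared variable
        weights = tf.get_variable(
            'embedding_weights', shape=shape, dtype=tf.float32)
        tf.add_to_collection(collection_name, weights)
        return weights

    with tf.Graph().as_default():
        with tf.variable_scope('col_a'):
            w1 = get_or_create_shared_weights(
                'aaa_bbb_shared_embedding', (3, 2))
        with tf.variable_scope('col_b'):
            w2 = get_or_create_shared_weights(
                'aaa_bbb_shared_embedding', (3, 2))
        assert w1 is w2  # both columns see the same variable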
@@ -3581,5 +3586,3 @@ class _SequenceCategoricalColumn(
weight_tensor,
shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
return _CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
-
-
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index a9013d8e42..511205451c 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -5350,9 +5350,9 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertIsNone(embedding_column_a.ckpt_to_load_from)
self.assertIsNone(embedding_column_b.ckpt_to_load_from)
self.assertEqual('aaa_bbb_shared_embedding',
- embedding_column_a.var_scope_name)
+ embedding_column_a.shared_embedding_collection_name)
self.assertEqual('aaa_bbb_shared_embedding',
- embedding_column_b.var_scope_name)
+ embedding_column_b.shared_embedding_collection_name)
self.assertIsNone(embedding_column_a.tensor_name_in_ckpt)
self.assertIsNone(embedding_column_b.tensor_name_in_ckpt)
self.assertIsNone(embedding_column_a.max_norm)
@@ -5399,9 +5399,9 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertEqual('my_combiner', embedding_column_a.combiner)
self.assertEqual('my_combiner', embedding_column_b.combiner)
self.assertEqual('shared_embedding_collection_name',
- embedding_column_a.var_scope_name)
+ embedding_column_a.shared_embedding_collection_name)
self.assertEqual('shared_embedding_collection_name',
- embedding_column_b.var_scope_name)
+ embedding_column_b.shared_embedding_collection_name)
self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from)
self.assertEqual('my_ckpt', embedding_column_b.ckpt_to_load_from)
self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt)
@@ -5452,7 +5452,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertEqual(embedding_dimension, embedding_column_a.dimension)
self.assertEqual('my_combiner', embedding_column_a.combiner)
self.assertEqual('shared_embedding_collection_name',
- embedding_column_a.var_scope_name)
+ embedding_column_a.shared_embedding_collection_name)
self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from)
self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt)
self.assertEqual(42., embedding_column_a.max_norm)