path: root/tensorflow/python/feature_column
author    Rohan Jain <rohanj@google.com>  2018-06-21 14:20:43 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-21 14:23:33 -0700
commit 25be72010a2e87e776814d2feb054d9ce43d7884 (patch)
tree   e9b89af18c51af915ddcdc75ad07b7b6e8c2becd /tensorflow/python/feature_column
parent 385ea351a11a73d464dab1e239c66dacb4bf2bc0 (diff)
Automated g4 rollback of changelist 201230316
PiperOrigin-RevId: 201585400
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--  tensorflow/python/feature_column/feature_column.py      | 123
-rw-r--r--  tensorflow/python/feature_column/feature_column_test.py |  10
2 files changed, 65 insertions(+), 68 deletions(-)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 40219e4b34..7d070dc27c 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -468,25 +468,13 @@ def linear_model(features,
def _add_to_collections(var, weight_collections):
- """Adds a var to the list of weight_collections provided.
-
- Handles the case for partitioned and non-partitioned variables.
-
- Args:
- var: A variable or Partitioned Variable.
- weight_collections: List of collections to add variable to.
- """
- for weight_collection in weight_collections:
- # The layer self.add_variable call already adds it to GLOBAL_VARIABLES.
- if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES:
- continue
- # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
- # so that we don't have to do this check.
- if isinstance(var, variables.PartitionedVariable):
- for constituent_var in list(var):
- ops.add_to_collection(weight_collection, constituent_var)
- else:
- ops.add_to_collection(weight_collection, var)
+ # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
+ # so that we don't have to do this check.
+ if isinstance(var, variables.PartitionedVariable):
+ for constituent_var in list(var):
+ ops.add_to_collections(weight_collections, constituent_var)
+ else:
+ ops.add_to_collections(weight_collections, var)
class _FCLinearWrapper(base.Layer):
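Note: the hunk above collapses the hand-rolled loop over collection names into a single ops.add_to_collections call per variable; only the PartitionedVariable fan-out survives, and the GLOBAL_VARIABLES skip moves out of this helper into its callers (see the _EmbeddingColumnLayer hunk below). A minimal sketch of the equivalence in TF 1.x graph mode (variable and collection names here are illustrative, not from the patch):

  import tensorflow as tf

  g = tf.Graph()
  with g.as_default():
    v = tf.Variable([1.0], name='w')
    # Old pattern: one ops.add_to_collection call per collection name.
    for name in ['weights_a', 'weights_b']:
      tf.add_to_collection(name, v)
    # New pattern: ops.add_to_collections fans one value out to all names.
    tf.add_to_collections(['weights_a', 'weights_b'], v)
    # Either way each named collection now holds the variable.
    assert v in g.get_collection('weights_a')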
@@ -597,8 +585,6 @@ class _LinearModel(training.Model):
self._feature_columns = _normalize_feature_columns(
feature_columns)
self._weight_collections = list(weight_collections or [])
- if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections:
- self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections:
self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
@@ -987,12 +973,7 @@ def shared_embedding_columns(
ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
is specified.
ValueError: if `initializer` is specified and is not callable.
- RuntimeError: if eager execution is enabled.
"""
- if context.executing_eagerly():
- raise RuntimeError('shared_embedding_columns are not supported when eager '
- 'execution is enabled.')
-
if (dimension is None) or (dimension < 1):
raise ValueError('Invalid dimension {}.'.format(dimension))
if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
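Note: dropping the guard here means construction of shared embedding columns no longer raises unconditionally under eager execution; the shared state now lives in a layer object rather than a graph collection, which is what made the guard necessary. A hedged sketch (TF 1.x opt-in eager; feature keys and sizes are made up):

  import tensorflow as tf

  tf.enable_eager_execution()
  a = tf.feature_column.categorical_column_with_identity('aaa', num_buckets=3)
  b = tf.feature_column.categorical_column_with_identity('bbb', num_buckets=3)
  # Before this patch the next line raised RuntimeError under eager;
  # after it, the columns are constructed normally.
  cols = tf.feature_column.shared_embedding_columns([a, b], dimension=2)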
@@ -1037,6 +1018,16 @@ def shared_embedding_columns(
shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
shared_embedding_collection_name += '_shared_embedding'
+ # Create the state (_SharedEmbeddingColumnLayer) here.
+ embedding_shape = num_buckets, dimension
+
+ shared_embedding_column_layer = _EmbeddingColumnLayer(
+ embedding_shape=embedding_shape,
+ initializer=initializer,
+ weight_collections=[],
+ trainable=trainable,
+ name=shared_embedding_collection_name)
+
result = []
for column in categorical_columns:
result.append(
@@ -1045,12 +1036,16 @@ def shared_embedding_columns(
initializer=initializer,
dimension=dimension,
combiner=combiner,
- shared_embedding_collection_name=shared_embedding_collection_name,
+ var_scope_name=shared_embedding_collection_name,
ckpt_to_load_from=ckpt_to_load_from,
tensor_name_in_ckpt=tensor_name_in_ckpt,
max_norm=max_norm,
trainable=trainable))
+ for single_result in result:
+ single_result._set_layer(shared_embedding_column_layer) # pylint: disable=protected-access
+ single_result._set_all_columns(result) # pylint: disable=protected-access
+
return result
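Note: with this hunk the shared state is a single _EmbeddingColumnLayer created up front and handed to every column via _set_layer, so the embedding table is built once and shared by object identity instead of being looked up in a named collection. A usage sketch (keys and sizes are illustrative):

  import tensorflow as tf

  a = tf.feature_column.categorical_column_with_identity('aaa', num_buckets=3)
  b = tf.feature_column.categorical_column_with_identity('bbb', num_buckets=3)
  cols = tf.feature_column.shared_embedding_columns([a, b], dimension=2)
  # Post-patch, both columns reference the same layer object
  # (an internal detail, not public API):
  assert cols[0]._layer is cols[1]._layer  # pylint: disable=protected-access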
@@ -1870,8 +1865,11 @@ class _EmbeddingColumnLayer(base.Layer):
dtype=dtypes.float32,
initializer=self._initializer,
trainable=self.trainable)
+ # self.add_variable already appends to GLOBAL_VARIABLES collection.
if self._weight_collections and not context.executing_eagerly():
- _add_to_collections(self._embedding_weight_var, self._weight_collections)
+ for weight_collection in self._weight_collections:
+ if weight_collection != ops.GraphKeys.GLOBAL_VARIABLES:
+ _add_to_collections(self._embedding_weight_var, [weight_collection])
self.built = True
def call(self, _):
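Note: this hunk is the caller-side counterpart of the _add_to_collections change. Layer.add_variable has already registered the new variable in GLOBAL_VARIABLES, so only the other requested collections are copied by hand, avoiding a duplicate entry; the same reasoning removes the GLOBAL_VARIABLES append from _LinearModel.__init__ above. A small demonstration of the automatic registration (TF 1.x):

  import tensorflow as tf

  with tf.Graph().as_default() as g:
    v = tf.get_variable('w', shape=[2])
    # Variable creation already records v in GLOBAL_VARIABLES, so adding
    # it there again would leave two entries for one variable.
    assert v in g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)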
@@ -2653,8 +2651,8 @@ class _SharedEmbeddingColumn(
collections.namedtuple(
'_SharedEmbeddingColumn',
('categorical_column', 'dimension', 'combiner', 'initializer',
- 'shared_embedding_collection_name', 'ckpt_to_load_from',
- 'tensor_name_in_ckpt', 'max_norm', 'trainable'))):
+ 'var_scope_name', 'ckpt_to_load_from', 'tensor_name_in_ckpt',
+ 'max_norm', 'trainable'))):
"""See `embedding_column`."""
@property
@@ -2665,7 +2663,7 @@ class _SharedEmbeddingColumn(
@property
def _var_scope_name(self):
- return self.shared_embedding_collection_name
+ return self.var_scope_name
@property
def _parse_example_spec(self):
@@ -2674,6 +2672,22 @@ class _SharedEmbeddingColumn(
def _transform_feature(self, inputs):
return inputs.get(self.categorical_column)
+ def _set_layer(self, layer):
+ self._layer = layer
+
+ def _set_all_columns(self, all_columns):
+ self._all_columns = all_columns
+
+ def _reset_config(self):
+ config = self._layer.get_config()
+ config['embedding_shape'] = (
+ self.categorical_column._num_buckets, # pylint: disable=protected-access
+ self.dimension)
+ config['initializer'] = self.initializer
+ self._layer = self._layer.__class__.from_config(config)
+ for column in self._all_columns:
+ column._set_layer(self._layer) # pylint: disable=protected-access
+
@property
def _variable_shape(self):
if not hasattr(self, '_shape'):
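Note: _reset_config rebuilds the shared layer by round-tripping it through the standard Keras get_config/from_config protocol, then re-points every sibling column at the fresh, unbuilt instance so its variables get created in the current graph. A minimal sketch of that round trip on a stock layer (Dense stands in for _EmbeddingColumnLayer; the tweaked key is illustrative):

  import tensorflow as tf

  layer = tf.keras.layers.Dense(4, name='shared')
  config = layer.get_config()          # plain dict of constructor kwargs
  config['units'] = 8                  # adjust state before rebuilding
  fresh = layer.__class__.from_config(config)
  # fresh is brand new and unbuilt; its weights will be created in
  # whichever graph is current when it is first called.
  assert fresh.units == 8 and not fresh.built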
@@ -2695,38 +2709,19 @@ class _SharedEmbeddingColumn(
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
- shared_embedding_collection = ops.get_collection(
- self.shared_embedding_collection_name)
- if shared_embedding_collection:
- if len(shared_embedding_collection) > 1:
- raise ValueError(
- 'Collection {} can only contain one variable. '
- 'Suggested fix A: Choose a unique name for this collection. '
- 'Suggested fix B: Do not add any variables to this collection. '
- 'The feature_column library already adds a variable under the '
- 'hood.'.format(shared_embedding_collection))
- embedding_weights = shared_embedding_collection[0]
- if embedding_weights.get_shape() != embedding_shape:
- raise ValueError(
- 'Shared embedding collection {} contains variable {} of '
- 'unexpected shape {}. Expected shape is {}. '
- 'Suggested fix A: Choose a unique name for this collection. '
- 'Suggested fix B: Do not add any variables to this collection. '
- 'The feature_column library already adds a variable under the '
- 'hood.'.format(self.shared_embedding_collection_name,
- embedding_weights.name,
- embedding_weights.get_shape(), embedding_shape))
- else:
- embedding_weights = variable_scope.get_variable(
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer,
- trainable=self.trainable and trainable,
- collections=weight_collections)
- ops.add_to_collection(self.shared_embedding_collection_name,
- embedding_weights)
+ self._layer.set_weight_collections(weight_collections)
+ embedding_weights = self._layer(
+ None, scope=variable_scope.get_variable_scope())
+ # If we're in graph mode and this is called with a different graph,
+ # then we should reset.
+ if not context.executing_eagerly() and (
+ ops.get_default_graph() !=
+ _get_graph_for_variable(embedding_weights)):
+ self._reset_config()
+ self._layer.set_weight_collections(weight_collections)
+ embedding_weights = self._layer(
+ None, scope=variable_scope.get_variable_scope())
+
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
if isinstance(to_restore, variables.PartitionedVariable):
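Note: the rewritten body obtains the weights by calling the shared layer and, in graph mode, rebuilds it via _reset_config whenever the layer was first built in a different graph (for example, when the column object outlives a tf.Graph). A hedged sketch of the staleness check (_get_graph_for_variable is the module's own helper; the setup below is illustrative):

  import tensorflow as tf

  g1, g2 = tf.Graph(), tf.Graph()
  with g1.as_default():
    w = tf.get_variable('embedding_weights', shape=[3, 2])
  with g2.as_default():
    # A variable created under g1 cannot be used in g2; detecting the
    # mismatch is what triggers the _reset_config() call above.
    assert tf.get_default_graph() is not w.graph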
@@ -3586,3 +3581,5 @@ class _SequenceCategoricalColumn(
weight_tensor,
shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
return _CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
+
+
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index dc3dde6710..d0023ffdd7 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -5350,9 +5350,9 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertIsNone(embedding_column_a.ckpt_to_load_from)
self.assertIsNone(embedding_column_b.ckpt_to_load_from)
self.assertEqual('aaa_bbb_shared_embedding',
- embedding_column_a.shared_embedding_collection_name)
+ embedding_column_a.var_scope_name)
self.assertEqual('aaa_bbb_shared_embedding',
- embedding_column_b.shared_embedding_collection_name)
+ embedding_column_b.var_scope_name)
self.assertIsNone(embedding_column_a.tensor_name_in_ckpt)
self.assertIsNone(embedding_column_b.tensor_name_in_ckpt)
self.assertIsNone(embedding_column_a.max_norm)
@@ -5399,9 +5399,9 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertEqual('my_combiner', embedding_column_a.combiner)
self.assertEqual('my_combiner', embedding_column_b.combiner)
self.assertEqual('shared_embedding_collection_name',
- embedding_column_a.shared_embedding_collection_name)
+ embedding_column_a.var_scope_name)
self.assertEqual('shared_embedding_collection_name',
- embedding_column_b.shared_embedding_collection_name)
+ embedding_column_b.var_scope_name)
self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from)
self.assertEqual('my_ckpt', embedding_column_b.ckpt_to_load_from)
self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt)
@@ -5452,7 +5452,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertEqual(embedding_dimension, embedding_column_a.dimension)
self.assertEqual('my_combiner', embedding_column_a.combiner)
self.assertEqual('shared_embedding_collection_name',
- embedding_column_a.shared_embedding_collection_name)
+ embedding_column_a.var_scope_name)
self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from)
self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt)
self.assertEqual(42., embedding_column_a.max_norm)