diff options
author | Rohan Jain <rohanj@google.com> | 2018-06-21 14:20:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-21 14:23:33 -0700 |
commit | 25be72010a2e87e776814d2feb054d9ce43d7884 (patch) | |
tree | e9b89af18c51af915ddcdc75ad07b7b6e8c2becd /tensorflow/python/feature_column | |
parent | 385ea351a11a73d464dab1e239c66dacb4bf2bc0 (diff) |
Automated g4 rollback of changelist 201230316
PiperOrigin-RevId: 201585400
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r-- | tensorflow/python/feature_column/feature_column.py | 123 | ||||
-rw-r--r-- | tensorflow/python/feature_column/feature_column_test.py | 10 |
2 files changed, 65 insertions, 68 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 40219e4b34..7d070dc27c 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -468,25 +468,13 @@ def linear_model(features, def _add_to_collections(var, weight_collections): - """Adds a var to the list of weight_collections provided. - - Handles the case for partitioned and non-partitioned variables. - - Args: - var: A variable or Partitioned Variable. - weight_collections: List of collections to add variable to. - """ - for weight_collection in weight_collections: - # The layer self.add_variable call already adds it to GLOBAL_VARIABLES. - if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES: - continue - # TODO(rohanj): Explore adding a _get_variable_list method on `Variable` - # so that we don't have to do this check. - if isinstance(var, variables.PartitionedVariable): - for constituent_var in list(var): - ops.add_to_collection(weight_collection, constituent_var) - else: - ops.add_to_collection(weight_collection, var) + # TODO(rohanj): Explore adding a _get_variable_list method on `Variable` + # so that we don't have to do this check. + if isinstance(var, variables.PartitionedVariable): + for constituent_var in list(var): + ops.add_to_collections(weight_collections, constituent_var) + else: + ops.add_to_collections(weight_collections, var) class _FCLinearWrapper(base.Layer): @@ -597,8 +585,6 @@ class _LinearModel(training.Model): self._feature_columns = _normalize_feature_columns( feature_columns) self._weight_collections = list(weight_collections or []) - if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections: - self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES) if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections: self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES) @@ -987,12 +973,7 @@ def shared_embedding_columns( ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt` is specified. ValueError: if `initializer` is specified and is not callable. - RuntimeError: if eager execution is enabled. """ - if context.executing_eagerly(): - raise RuntimeError('shared_embedding_columns are not supported when eager ' - 'execution is enabled.') - if (dimension is None) or (dimension < 1): raise ValueError('Invalid dimension {}.'.format(dimension)) if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None): @@ -1037,6 +1018,16 @@ def shared_embedding_columns( shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns) shared_embedding_collection_name += '_shared_embedding' + # Create the state (_SharedEmbeddingColumnLayer) here. + embedding_shape = num_buckets, dimension + + shared_embedding_column_layer = _EmbeddingColumnLayer( + embedding_shape=embedding_shape, + initializer=initializer, + weight_collections=[], + trainable=trainable, + name=shared_embedding_collection_name) + result = [] for column in categorical_columns: result.append( @@ -1045,12 +1036,16 @@ def shared_embedding_columns( initializer=initializer, dimension=dimension, combiner=combiner, - shared_embedding_collection_name=shared_embedding_collection_name, + var_scope_name=shared_embedding_collection_name, ckpt_to_load_from=ckpt_to_load_from, tensor_name_in_ckpt=tensor_name_in_ckpt, max_norm=max_norm, trainable=trainable)) + for single_result in result: + single_result._set_layer(shared_embedding_column_layer) # pylint: disable=protected-access + single_result._set_all_columns(result) # pylint: disable=protected-access + return result @@ -1870,8 +1865,11 @@ class _EmbeddingColumnLayer(base.Layer): dtype=dtypes.float32, initializer=self._initializer, trainable=self.trainable) + # self.add_variable already appends to GLOBAL_VARIABLES collection. if self._weight_collections and not context.executing_eagerly(): - _add_to_collections(self._embedding_weight_var, self._weight_collections) + for weight_collection in self._weight_collections: + if weight_collection != ops.GraphKeys.GLOBAL_VARIABLES: + _add_to_collections(self._embedding_weight_var, [weight_collection]) self.built = True def call(self, _): @@ -2653,8 +2651,8 @@ class _SharedEmbeddingColumn( collections.namedtuple( '_SharedEmbeddingColumn', ('categorical_column', 'dimension', 'combiner', 'initializer', - 'shared_embedding_collection_name', 'ckpt_to_load_from', - 'tensor_name_in_ckpt', 'max_norm', 'trainable'))): + 'var_scope_name', 'ckpt_to_load_from', 'tensor_name_in_ckpt', + 'max_norm', 'trainable'))): """See `embedding_column`.""" @property @@ -2665,7 +2663,7 @@ class _SharedEmbeddingColumn( @property def _var_scope_name(self): - return self.shared_embedding_collection_name + return self.var_scope_name @property def _parse_example_spec(self): @@ -2674,6 +2672,22 @@ class _SharedEmbeddingColumn( def _transform_feature(self, inputs): return inputs.get(self.categorical_column) + def _set_layer(self, layer): + self._layer = layer + + def _set_all_columns(self, all_columns): + self._all_columns = all_columns + + def _reset_config(self): + config = self._layer.get_config() + config['embedding_shape'] = ( + self.categorical_column._num_buckets, # pylint: disable=protected-access + self.dimension) + config['initializer'] = self.initializer + self._layer = self._layer.__class__.from_config(config) + for column in self._all_columns: + column._set_layer(self._layer) # pylint: disable=protected-access + @property def _variable_shape(self): if not hasattr(self, '_shape'): @@ -2695,38 +2709,19 @@ class _SharedEmbeddingColumn( sparse_ids = sparse_tensors.id_tensor sparse_weights = sparse_tensors.weight_tensor - embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access - shared_embedding_collection = ops.get_collection( - self.shared_embedding_collection_name) - if shared_embedding_collection: - if len(shared_embedding_collection) > 1: - raise ValueError( - 'Collection {} can only contain one variable. ' - 'Suggested fix A: Choose a unique name for this collection. ' - 'Suggested fix B: Do not add any variables to this collection. ' - 'The feature_column library already adds a variable under the ' - 'hood.'.format(shared_embedding_collection)) - embedding_weights = shared_embedding_collection[0] - if embedding_weights.get_shape() != embedding_shape: - raise ValueError( - 'Shared embedding collection {} contains variable {} of ' - 'unexpected shape {}. Expected shape is {}. ' - 'Suggested fix A: Choose a unique name for this collection. ' - 'Suggested fix B: Do not add any variables to this collection. ' - 'The feature_column library already adds a variable under the ' - 'hood.'.format(self.shared_embedding_collection_name, - embedding_weights.name, - embedding_weights.get_shape(), embedding_shape)) - else: - embedding_weights = variable_scope.get_variable( - name='embedding_weights', - shape=embedding_shape, - dtype=dtypes.float32, - initializer=self.initializer, - trainable=self.trainable and trainable, - collections=weight_collections) - ops.add_to_collection(self.shared_embedding_collection_name, - embedding_weights) + self._layer.set_weight_collections(weight_collections) + embedding_weights = self._layer( + None, scope=variable_scope.get_variable_scope()) + # If we're in graph mode and this is called with a different graph, + # then we should reset. + if not context.executing_eagerly() and ( + ops.get_default_graph() != + _get_graph_for_variable(embedding_weights)): + self._reset_config() + self._layer.set_weight_collections(weight_collections) + embedding_weights = self._layer( + None, scope=variable_scope.get_variable_scope()) + if self.ckpt_to_load_from is not None: to_restore = embedding_weights if isinstance(to_restore, variables.PartitionedVariable): @@ -3586,3 +3581,5 @@ class _SequenceCategoricalColumn( weight_tensor, shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0)) return _CategoricalColumn.IdWeightPair(id_tensor, weight_tensor) + + diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index dc3dde6710..d0023ffdd7 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -5350,9 +5350,9 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertIsNone(embedding_column_a.ckpt_to_load_from) self.assertIsNone(embedding_column_b.ckpt_to_load_from) self.assertEqual('aaa_bbb_shared_embedding', - embedding_column_a.shared_embedding_collection_name) + embedding_column_a.var_scope_name) self.assertEqual('aaa_bbb_shared_embedding', - embedding_column_b.shared_embedding_collection_name) + embedding_column_b.var_scope_name) self.assertIsNone(embedding_column_a.tensor_name_in_ckpt) self.assertIsNone(embedding_column_b.tensor_name_in_ckpt) self.assertIsNone(embedding_column_a.max_norm) @@ -5399,9 +5399,9 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertEqual('my_combiner', embedding_column_a.combiner) self.assertEqual('my_combiner', embedding_column_b.combiner) self.assertEqual('shared_embedding_collection_name', - embedding_column_a.shared_embedding_collection_name) + embedding_column_a.var_scope_name) self.assertEqual('shared_embedding_collection_name', - embedding_column_b.shared_embedding_collection_name) + embedding_column_b.var_scope_name) self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from) self.assertEqual('my_ckpt', embedding_column_b.ckpt_to_load_from) self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt) @@ -5452,7 +5452,7 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertEqual(embedding_dimension, embedding_column_a.dimension) self.assertEqual('my_combiner', embedding_column_a.combiner) self.assertEqual('shared_embedding_collection_name', - embedding_column_a.shared_embedding_collection_name) + embedding_column_a.var_scope_name) self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from) self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt) self.assertEqual(42., embedding_column_a.max_norm) |