diff options
author | Rohan Jain <rohanj@google.com> | 2018-05-31 20:54:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-31 20:57:30 -0700 |
commit | 54b20c4be0372fb14ec9a289e4d7de7f67c03ff6 (patch) | |
tree | 1dfcd43865d1401c87f40709cb50fd78caeb3fce /tensorflow/python/feature_column | |
parent | 8d1d8c1b436b84eeaede95c6ed53308a8a97cb08 (diff) |
Making sure that weight_collections are respected for shared_embedding_columns
PiperOrigin-RevId: 198823349
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r-- | tensorflow/python/feature_column/feature_column.py | 11 | ||||
-rw-r--r-- | tensorflow/python/feature_column/feature_column_test.py | 66 |
2 files changed, 77 insertions, 0 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 7aa46af828..59801efc26 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -1799,6 +1799,15 @@ class _EmbeddingColumnLayer(base.Layer): self._initializer = initializer self._weight_collections = weight_collections + def set_weight_collections(self, weight_collections): + """Sets the weight collections for the layer. + + Args: + weight_collections: A list of collection names to which the Variable will + be added. + """ + self._weight_collections = weight_collections + def build(self, _): self._embedding_weight_var = self.add_variable( name='embedding_weights', @@ -2604,6 +2613,7 @@ class _SharedEmbeddingColumn( sparse_ids = sparse_tensors.id_tensor sparse_weights = sparse_tensors.weight_tensor + self._layer.set_weight_collections(weight_collections) embedding_weights = self._layer( None, scope=variable_scope.get_variable_scope()) # If we're in graph mode and this is called with a different graph, @@ -2612,6 +2622,7 @@ class _SharedEmbeddingColumn( ops.get_default_graph() != _get_graph_for_variable(embedding_weights)): self._reset_config() + self._layer.set_weight_collections(weight_collections) embedding_weights = self._layer( None, scope=variable_scope.get_variable_scope()) diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index 0af7b9baa9..627430d6bc 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -5615,6 +5615,72 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval()) self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval()) + def test_get_dense_tensor_weight_collections(self): + # Inputs. + vocabulary_size = 3 + # -1 values are ignored. + input_a = np.array([ + [2, -1, -1], # example 0, ids [2] + [0, 1, -1] + ]) # example 1, ids [0, 1] + input_b = np.array([ + [0, -1, -1], # example 0, ids [0] + [-1, -1, -1] + ]) # example 1, ids [] + input_features = {'aaa': input_a, 'bbb': input_b} + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups_a = ( + # example 0: + (7., 11.), # ids [2], embedding = [7, 11] + # example 1: + (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + ) + expected_lookups_b = ( + # example 0: + (1., 2.), # ids [0], embedding = [1, 2] + # example 1: + (0., 0.), # ids [], embedding = [0, 0] + ) + + # Build columns. + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_a, embedding_column_b = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + fc.input_layer( + input_features, [embedding_column_a, embedding_column_b], + weight_collections=('my_vars',)) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + tuple(v.name for v in global_vars)) + my_vars = ops.get_collection('my_vars') + self.assertItemsEqual( + ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + tuple(v.name for v in my_vars)) + def test_get_dense_tensor_placeholder_inputs(self): # Inputs. vocabulary_size = 3 |