author    Rohan Jain <rohanj@google.com>                    2018-05-31 20:54:27 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-05-31 20:57:30 -0700
commit    54b20c4be0372fb14ec9a289e4d7de7f67c03ff6 (patch)
tree      1dfcd43865d1401c87f40709cb50fd78caeb3fce /tensorflow/python/feature_column
parent    8d1d8c1b436b84eeaede95c6ed53308a8a97cb08 (diff)
Making sure that weight_collections are respected for shared_embedding_columns
PiperOrigin-RevId: 198823349
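
As a usage sketch of what this change affects (hedged: this mirrors the new test
added below and assumes the TF 1.x graph-mode API; the public tf.feature_column
aliases stand in for the test's internal `fc` module):

    import numpy as np
    import tensorflow as tf

    # Two categorical columns sharing one embedding table.
    col_a = tf.feature_column.categorical_column_with_identity(
        'aaa', num_buckets=3)
    col_b = tf.feature_column.categorical_column_with_identity(
        'bbb', num_buckets=3)
    emb_a, emb_b = tf.feature_column.shared_embedding_columns(
        [col_a, col_b], dimension=2)

    # -1 entries are treated as missing ids.
    features = {'aaa': np.array([[2, -1, -1], [0, 1, -1]]),
                'bbb': np.array([[0, -1, -1], [-1, -1, -1]])}

    # With this fix, weight_collections is honored for shared embedding
    # columns: the shared variable joins 'my_vars' as well as GLOBAL_VARIABLES.
    tf.feature_column.input_layer(
        features, [emb_a, emb_b], weight_collections=('my_vars',))

    print([v.name for v in tf.get_collection('my_vars')])
    # ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0']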
Diffstat (limited to 'tensorflow/python/feature_column')
 tensorflow/python/feature_column/feature_column.py      | 11 +
 tensorflow/python/feature_column/feature_column_test.py | 66 +
 2 files changed, 77 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 7aa46af828..59801efc26 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -1799,6 +1799,15 @@ class _EmbeddingColumnLayer(base.Layer):
     self._initializer = initializer
     self._weight_collections = weight_collections
 
+  def set_weight_collections(self, weight_collections):
+    """Sets the weight collections for the layer.
+
+    Args:
+      weight_collections: A list of collection names to which the Variable will
+        be added.
+    """
+    self._weight_collections = weight_collections
+
   def build(self, _):
     self._embedding_weight_var = self.add_variable(
         name='embedding_weights',
@@ -2604,6 +2613,7 @@ class _SharedEmbeddingColumn(
     sparse_ids = sparse_tensors.id_tensor
     sparse_weights = sparse_tensors.weight_tensor
 
+    self._layer.set_weight_collections(weight_collections)
     embedding_weights = self._layer(
         None, scope=variable_scope.get_variable_scope())
     # If we're in graph mode and this is called with a different graph,
@@ -2612,6 +2622,7 @@ class _SharedEmbeddingColumn(
         ops.get_default_graph() !=
         _get_graph_for_variable(embedding_weights)):
       self._reset_config()
+      self._layer.set_weight_collections(weight_collections)
       embedding_weights = self._layer(
           None, scope=variable_scope.get_variable_scope())
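
Note on the mechanism: the column now pushes the caller's weight_collections
into the shared layer immediately before invoking it, so whatever collections
are set when the layer builds its variable are the ones the variable joins. A
simplified, hypothetical sketch of that pattern (toy SharedLayer, not the real
TensorFlow classes):

    class SharedLayer(object):
      """Toy stand-in for _EmbeddingColumnLayer (illustration only)."""

      def __init__(self):
        self._weight_collections = None
        self._weights = None

      def set_weight_collections(self, weight_collections):
        # Mirrors the method added above: refresh the collections before the
        # layer builds (or rebuilds) its variable.
        self._weight_collections = weight_collections

      def __call__(self):
        if self._weights is None:
          # The collections set at build time are the ones recorded here.
          self._weights = ('embedding_weights',
                           list(self._weight_collections or []))
        return self._weights

    layer = SharedLayer()
    layer.set_weight_collections(['my_vars'])  # set before the call, as in the fix
    print(layer())  # ('embedding_weights', ['my_vars'])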
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 0af7b9baa9..627430d6bc 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -5615,6 +5615,72 @@ class SharedEmbeddingColumnTest(test.TestCase):
       self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval())
       self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval())
 
+  def test_get_dense_tensor_weight_collections(self):
+    # Inputs.
+    vocabulary_size = 3
+    # -1 values are ignored.
+    input_a = np.array([
+        [2, -1, -1],  # example 0, ids [2]
+        [0, 1, -1]
+    ])  # example 1, ids [0, 1]
+    input_b = np.array([
+        [0, -1, -1],  # example 0, ids [0]
+        [-1, -1, -1]
+    ])  # example 1, ids []
+    input_features = {'aaa': input_a, 'bbb': input_b}
+
+    # Embedding variable.
+    embedding_dimension = 2
+    embedding_values = (
+        (1., 2.),  # id 0
+        (3., 5.),  # id 1
+        (7., 11.)  # id 2
+    )
+
+    def _initializer(shape, dtype, partition_info):
+      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+      self.assertEqual(dtypes.float32, dtype)
+      self.assertIsNone(partition_info)
+      return embedding_values
+
+    # Expected lookup result, using combiner='mean'.
+    expected_lookups_a = (
+        # example 0:
+        (7., 11.),  # ids [2], embedding = [7, 11]
+        # example 1:
+        (2., 3.5),  # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+    )
+    expected_lookups_b = (
+        # example 0:
+        (1., 2.),  # ids [0], embedding = [1, 2]
+        # example 1:
+        (0., 0.),  # ids [], embedding = [0, 0]
+    )
+
+    # Build columns.
+    categorical_column_a = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    categorical_column_b = fc.categorical_column_with_identity(
+        key='bbb', num_buckets=vocabulary_size)
+    embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+        [categorical_column_a, categorical_column_b],
+        dimension=embedding_dimension,
+        initializer=_initializer)
+
+    fc.input_layer(
+        input_features, [embedding_column_a, embedding_column_b],
+        weight_collections=('my_vars',))
+
+    # Assert expected embedding variable and lookups.
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(
+        ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
+        tuple(v.name for v in global_vars))
+    my_vars = ops.get_collection('my_vars')
+    self.assertItemsEqual(
+        ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
+        tuple(v.name for v in my_vars))
+
   def test_get_dense_tensor_placeholder_inputs(self):
     # Inputs.
     vocabulary_size = 3