diff options
author | 2018-06-13 18:01:03 -0700 | |
---|---|---|
committer | 2018-06-13 18:03:50 -0700 | |
commit | 462a7e063169010899ce0fa9534f6d7c980f1116 (patch) | |
tree | bd55a540786567459d8a6d756bcb21fd40079839 /tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py | |
parent | 1babacb30c63e7a5231c3aaaac79bc56f68bf3ec (diff) |
Add sequential functionality to _SharedEmbeddingColumn.
PiperOrigin-RevId: 200485876
Diffstat (limited to 'tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py')
-rw-r--r-- | tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py | 279 |
1 files changed, 278 insertions, 1 deletions
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
index 88f5d53516..ee74cf56dc 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
@@ -109,7 +109,7 @@ class SequenceInputLayerTest(test.TestCase):
           expected_sequence_length, sequence_length.eval(session=sess))
 
   def test_embedding_column_with_non_sequence_categorical(self):
-    """Tests that error is raised for non-sequence categorical column."""
+    """Tests that error is raised for non-sequence embedding column."""
     vocabulary_size = 3
     sparse_input = sparse_tensor.SparseTensorValue(
         # example 0, ids [2]
@@ -131,6 +131,107 @@ class SequenceInputLayerTest(test.TestCase):
           features={'aaa': sparse_input},
           feature_columns=[embedding_column_a])
 
+  def test_shared_embedding_column(self):
+    vocabulary_size = 3
+    sparse_input_a = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=(2, 0, 1),
+        dense_shape=(2, 2))
+    sparse_input_b = sparse_tensor.SparseTensorValue(
+        # example 0, ids [1]
+        # example 1, ids [2, 0]
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=(1, 2, 0),
+        dense_shape=(2, 2))
+
+    embedding_dimension = 2
+    embedding_values = (
+        (1., 2.),  # id 0
+        (3., 4.),  # id 1
+        (5., 6.)  # id 2
+    )
+
+    def _get_initializer(embedding_dimension, embedding_values):
+
+      def _initializer(shape, dtype, partition_info):
+        self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+        self.assertEqual(dtypes.float32, dtype)
+        self.assertIsNone(partition_info)
+        return embedding_values
+
+      return _initializer
+
+    expected_input_layer = [
+        # example 0, ids_a [2], ids_b [1]
+        [[5., 6., 3., 4.], [0., 0., 0., 0.]],
+        # example 1, ids_a [0, 1], ids_b [2, 0]
+        [[1., 2., 5., 6.], [3., 4., 1., 2.]],
+    ]
+    expected_sequence_length = [1, 2]
+
+    categorical_column_a = sfc.sequence_categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    categorical_column_b = sfc.sequence_categorical_column_with_identity(
+        key='bbb', num_buckets=vocabulary_size)
+    # Test that columns are reordered alphabetically.
+    shared_embedding_columns = fc.shared_embedding_columns(
+        [categorical_column_b, categorical_column_a],
+        dimension=embedding_dimension,
+        initializer=_get_initializer(embedding_dimension, embedding_values))
+
+    input_layer, sequence_length = sfc.sequence_input_layer(
+        features={
+            'aaa': sparse_input_a,
+            'bbb': sparse_input_b,
+        },
+        feature_columns=shared_embedding_columns)
+
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(
+        ('sequence_input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
+        tuple([v.name for v in global_vars]))
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess))
+      self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess))
+      self.assertAllEqual(
+          expected_sequence_length, sequence_length.eval(session=sess))
+
+  def test_shared_embedding_column_with_non_sequence_categorical(self):
+    """Tests that error is raised for non-sequence shared embedding column."""
+    vocabulary_size = 3
+    sparse_input_a = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=(2, 0, 1),
+        dense_shape=(2, 2))
+    sparse_input_b = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=(2, 0, 1),
+        dense_shape=(2, 2))
+
+    categorical_column_a = fc.categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    categorical_column_b = fc.categorical_column_with_identity(
+        key='bbb', num_buckets=vocabulary_size)
+    shared_embedding_columns = fc.shared_embedding_columns(
+        [categorical_column_a, categorical_column_b], dimension=2)
+
+    with self.assertRaisesRegexp(
+        ValueError,
+        r'In embedding_column: aaa_shared_embedding\. categorical_column must '
+        r'be of type _SequenceCategoricalColumn to use sequence_input_layer\.'):
+      _, _ = sfc.sequence_input_layer(
+          features={
+              'aaa': sparse_input_a,
+              'bbb': sparse_input_b
+          },
+          feature_columns=shared_embedding_columns)
+
   def test_indicator_column(self):
     vocabulary_size_a = 3
     sparse_input_a = sparse_tensor.SparseTensorValue(
@@ -577,6 +678,182 @@ class SequenceEmbeddingColumnTest(test.TestCase):
           expected_sequence_length, sequence_length.eval(session=sess))
 
 
+class SequenceSharedEmbeddingColumnTest(test.TestCase):
+
+  def test_get_sequence_dense_tensor(self):
+    vocabulary_size = 3
+    embedding_dimension = 2
+    embedding_values = (
+        (1., 2.),  # id 0
+        (3., 5.),  # id 1
+        (7., 11.)  # id 2
+    )
+
+    def _initializer(shape, dtype, partition_info):
+      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+      self.assertEqual(dtypes.float32, dtype)
+      self.assertIsNone(partition_info)
+      return embedding_values
+
+    sparse_input_a = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        # example 2, ids []
+        # example 3, ids [1]
+        indices=((0, 0), (1, 0), (1, 1), (3, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(4, 2))
+    sparse_input_b = sparse_tensor.SparseTensorValue(
+        # example 0, ids [1]
+        # example 1, ids [0, 2]
+        # example 2, ids [0]
+        # example 3, ids []
+        indices=((0, 0), (1, 0), (1, 1), (2, 0)),
+        values=(1, 0, 2, 0),
+        dense_shape=(4, 2))
+
+    expected_lookups_a = [
+        # example 0, ids [2]
+        [[7., 11.], [0., 0.]],
+        # example 1, ids [0, 1]
+        [[1., 2.], [3., 5.]],
+        # example 2, ids []
+        [[0., 0.], [0., 0.]],
+        # example 3, ids [1]
+        [[3., 5.], [0., 0.]],
+    ]
+
+    expected_lookups_b = [
+        # example 0, ids [1]
+        [[3., 5.], [0., 0.]],
+        # example 1, ids [0, 2]
+        [[1., 2.], [7., 11.]],
+        # example 2, ids [0]
+        [[1., 2.], [0., 0.]],
+        # example 3, ids []
+        [[0., 0.], [0., 0.]],
+    ]
+
+    categorical_column_a = sfc.sequence_categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+    categorical_column_b = sfc.sequence_categorical_column_with_identity(
+        key='bbb', num_buckets=vocabulary_size)
+    shared_embedding_columns = fc.shared_embedding_columns(
+        [categorical_column_a, categorical_column_b],
+        dimension=embedding_dimension,
+        initializer=_initializer)
+
+    embedding_lookup_a = shared_embedding_columns[0]._get_sequence_dense_tensor(
+        _LazyBuilder({
+            'aaa': sparse_input_a
+        }))[0]
+    embedding_lookup_b = shared_embedding_columns[1]._get_sequence_dense_tensor(
+        _LazyBuilder({
+            'bbb': sparse_input_b
+        }))[0]
+
+    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    self.assertItemsEqual(('embedding_weights:0',),
+                          tuple([v.name for v in global_vars]))
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess))
+      self.assertAllEqual(
+          expected_lookups_a, embedding_lookup_a.eval(session=sess))
+      self.assertAllEqual(
+          expected_lookups_b, embedding_lookup_b.eval(session=sess))
+
+  def test_sequence_length(self):
+    vocabulary_size = 3
+
+    sparse_input_a = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids [0, 1]
+        indices=((0, 0), (1, 0), (1, 1)),
+        values=(2, 0, 1),
+        dense_shape=(2, 2))
+    expected_sequence_length_a = [1, 2]
+    categorical_column_a = sfc.sequence_categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+
+    sparse_input_b = sparse_tensor.SparseTensorValue(
+        # example 0, ids [0, 2]
+        # example 1, ids [1]
+        indices=((0, 0), (0, 1), (1, 0)),
+        values=(0, 2, 1),
+        dense_shape=(2, 2))
+    expected_sequence_length_b = [2, 1]
+    categorical_column_b = sfc.sequence_categorical_column_with_identity(
+        key='bbb', num_buckets=vocabulary_size)
+    shared_embedding_columns = fc.shared_embedding_columns(
+        [categorical_column_a, categorical_column_b], dimension=2)
+
+    sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor(
+        _LazyBuilder({
+            'aaa': sparse_input_a
+        }))[1]
+    sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor(
+        _LazyBuilder({
+            'bbb': sparse_input_b
+        }))[1]
+
+    with monitored_session.MonitoredSession() as sess:
+      sequence_length_a = sess.run(sequence_length_a)
+      self.assertAllEqual(expected_sequence_length_a, sequence_length_a)
+      self.assertEqual(np.int64, sequence_length_a.dtype)
+      sequence_length_b = sess.run(sequence_length_b)
+      self.assertAllEqual(expected_sequence_length_b, sequence_length_b)
+      self.assertEqual(np.int64, sequence_length_b.dtype)
+
+  def test_sequence_length_with_empty_rows(self):
+    """Tests _sequence_length when some examples do not have ids."""
+    vocabulary_size = 3
+    sparse_input_a = sparse_tensor.SparseTensorValue(
+        # example 0, ids []
+        # example 1, ids [2]
+        # example 2, ids [0, 1]
+        # example 3, ids []
+        # example 4, ids [1]
+        # example 5, ids []
+        indices=((1, 0), (2, 0), (2, 1), (4, 0)),
+        values=(2, 0, 1, 1),
+        dense_shape=(6, 2))
+    expected_sequence_length_a = [0, 1, 2, 0, 1, 0]
+    categorical_column_a = sfc.sequence_categorical_column_with_identity(
+        key='aaa', num_buckets=vocabulary_size)
+
+    sparse_input_b = sparse_tensor.SparseTensorValue(
+        # example 0, ids [2]
+        # example 1, ids []
+        # example 2, ids []
+        # example 3, ids []
+        # example 4, ids [1]
+        # example 5, ids [0, 1]
+        indices=((0, 0), (4, 0), (5, 0), (5, 1)),
+        values=(2, 1, 0, 1),
+        dense_shape=(6, 2))
+    expected_sequence_length_b = [1, 0, 0, 0, 1, 2]
+    categorical_column_b = sfc.sequence_categorical_column_with_identity(
+        key='bbb', num_buckets=vocabulary_size)
+
+    shared_embedding_columns = fc.shared_embedding_columns(
+        [categorical_column_a, categorical_column_b], dimension=2)
+
+    sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor(
+        _LazyBuilder({
+            'aaa': sparse_input_a
+        }))[1]
+    sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor(
+        _LazyBuilder({
+            'bbb': sparse_input_b
+        }))[1]
+
+    with monitored_session.MonitoredSession() as sess:
+      self.assertAllEqual(
+          expected_sequence_length_a, sequence_length_a.eval(session=sess))
+      self.assertAllEqual(
+          expected_sequence_length_b, sequence_length_b.eval(session=sess))
+
+
+class SequenceIndicatorColumnTest(test.TestCase):
 
   def test_get_sequence_dense_tensor(self):