From 153decedefc8da1fbd0717f4223b4b053e7aa517 Mon Sep 17 00:00:00 2001 From: Karmel Allison Date: Mon, 8 Oct 2018 10:36:38 -0700 Subject: Add support for SequenceExamples to sequence_feature_columns PiperOrigin-RevId: 216210141 --- .../contrib/estimator/python/estimator/rnn.py | 54 +- tensorflow/contrib/feature_column/BUILD | 21 + .../feature_column/sequence_feature_column.py | 72 +- .../sequence_feature_column_integration_test.py | 280 +++++++ .../feature_column/sequence_feature_column_test.py | 912 ++++++++++++++------- 5 files changed, 980 insertions(+), 359 deletions(-) create mode 100644 tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py (limited to 'tensorflow/contrib') diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py index 98660bb731..c595f47395 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn.py @@ -30,7 +30,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.layers import core as core_layers from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables @@ -92,55 +91,6 @@ def _make_rnn_cell_fn(num_units, cell_type='basic_rnn'): return rnn_cell_fn -def _concatenate_context_input(sequence_input, context_input): - """Replicates `context_input` across all timesteps of `sequence_input`. - - Expands dimension 1 of `context_input` then tiles it `sequence_length` times. - This value is appended to `sequence_input` on dimension 2 and the result is - returned. - - Args: - sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size, - padded_length, d0]`. - context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`. - - Returns: - A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, - d0 + d1]`. - - Raises: - ValueError: If `sequence_input` does not have rank 3 or `context_input` does - not have rank 2. - """ - seq_rank_check = check_ops.assert_rank( - sequence_input, - 3, - message='sequence_input must have rank 3', - data=[array_ops.shape(sequence_input)]) - seq_type_check = check_ops.assert_type( - sequence_input, - dtypes.float32, - message='sequence_input must have dtype float32; got {}.'.format( - sequence_input.dtype)) - ctx_rank_check = check_ops.assert_rank( - context_input, - 2, - message='context_input must have rank 2', - data=[array_ops.shape(context_input)]) - ctx_type_check = check_ops.assert_type( - context_input, - dtypes.float32, - message='context_input must have dtype float32; got {}.'.format( - context_input.dtype)) - with ops.control_dependencies( - [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]): - padded_length = array_ops.shape(sequence_input)[1] - tiled_context_input = array_ops.tile( - array_ops.expand_dims(context_input, 1), - array_ops.concat([[1], [padded_length], [1]], 0)) - return array_ops.concat([sequence_input, tiled_context_input], 2) - - def _select_last_activations(activations, sequence_lengths): """Selects the nth set of activations for each n in `sequence_length`. @@ -222,8 +172,8 @@ def _rnn_logit_fn_builder(output_units, rnn_cell_fn, sequence_feature_columns, context_input = feature_column_lib.input_layer( features=features, feature_columns=context_feature_columns) - sequence_input = _concatenate_context_input(sequence_input, - context_input) + sequence_input = seq_fc.concatenate_context_input( + context_input, sequence_input) cell = rnn_cell_fn(mode) # Ignore output state. diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index aab7d0c9e8..a926ffd598 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -27,6 +27,7 @@ py_library( "//tensorflow/python:check_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:sparse_ops", "//tensorflow/python:tensor_shape", @@ -46,9 +47,29 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//tensorflow/python/feature_column", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "sequence_feature_column_integration_test", + srcs = ["python/feature_column/sequence_feature_column_integration_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":sequence_feature_column", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python/feature_column", + "//tensorflow/python/keras:layers", ], ) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index 05bcdac2ca..dd6da35ed0 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -33,7 +33,6 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variable_scope # pylint: disable=protected-access -# TODO(b/73827486): Support SequenceExample. def sequence_input_layer( @@ -110,6 +109,7 @@ def sequence_input_layer( output_tensors = [] sequence_lengths = [] ordered_columns = [] + for column in sorted(feature_columns, key=lambda x: x.name): ordered_columns.append(column) with variable_scope.variable_scope( @@ -121,17 +121,67 @@ def sequence_input_layer( # Flattens the final dimension to produce a 3D Tensor. num_elements = column._variable_shape.num_elements() shape = array_ops.shape(dense_tensor) + target_shape = [shape[0], shape[1], num_elements] output_tensors.append( - array_ops.reshape( - dense_tensor, - shape=array_ops.concat([shape[:2], [num_elements]], axis=0))) + array_ops.reshape(dense_tensor, shape=target_shape)) sequence_lengths.append(sequence_length) + fc._verify_static_batch_size_equality(output_tensors, ordered_columns) fc._verify_static_batch_size_equality(sequence_lengths, ordered_columns) sequence_length = _assert_all_equal_and_return(sequence_lengths) + return array_ops.concat(output_tensors, -1), sequence_length +def concatenate_context_input(context_input, sequence_input): + """Replicates `context_input` across all timesteps of `sequence_input`. + + Expands dimension 1 of `context_input` then tiles it `sequence_length` times. + This value is appended to `sequence_input` on dimension 2 and the result is + returned. + + Args: + context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`. + sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size, + padded_length, d0]`. + + Returns: + A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, + d0 + d1]`. + + Raises: + ValueError: If `sequence_input` does not have rank 3 or `context_input` does + not have rank 2. + """ + seq_rank_check = check_ops.assert_rank( + sequence_input, + 3, + message='sequence_input must have rank 3', + data=[array_ops.shape(sequence_input)]) + seq_type_check = check_ops.assert_type( + sequence_input, + dtypes.float32, + message='sequence_input must have dtype float32; got {}.'.format( + sequence_input.dtype)) + ctx_rank_check = check_ops.assert_rank( + context_input, + 2, + message='context_input must have rank 2', + data=[array_ops.shape(context_input)]) + ctx_type_check = check_ops.assert_type( + context_input, + dtypes.float32, + message='context_input must have dtype float32; got {}.'.format( + context_input.dtype)) + with ops.control_dependencies( + [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]): + padded_length = array_ops.shape(sequence_input)[1] + tiled_context_input = array_ops.tile( + array_ops.expand_dims(context_input, 1), + array_ops.concat([[1], [padded_length], [1]], 0)) + return array_ops.concat([sequence_input, tiled_context_input], 2) + + def sequence_categorical_column_with_identity( key, num_buckets, default_value=None): """Returns a feature column that represents sequences of integers. @@ -453,9 +503,17 @@ class _SequenceNumericColumn( [array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape], axis=0) dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape) - sequence_length = fc._sequence_length_from_sparse_tensor( - sp_tensor, num_elements=self._variable_shape.num_elements()) + + # Get the number of timesteps per example + # For the 2D case, the raw values are grouped according to num_elements; + # for the 3D case, the grouping happens in the third dimension, and + # sequence length is not affected. + num_elements = (self._variable_shape.num_elements() + if sp_tensor.shape.ndims == 2 else 1) + seq_length = fc._sequence_length_from_sparse_tensor( + sp_tensor, num_elements=num_elements) + return fc._SequenceDenseColumn.TensorSequenceLengthPair( - dense_tensor=dense_tensor, sequence_length=sequence_length) + dense_tensor=dense_tensor, sequence_length=seq_length) # pylint: enable=protected-access diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py new file mode 100644 index 0000000000..d8ca363627 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py @@ -0,0 +1,280 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Integration test for sequence feature columns with SequenceExamples.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import string +import tempfile + +from google.protobuf import text_format + +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.keras.layers import recurrent +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class SequenceFeatureColumnIntegrationTest(test.TestCase): + + def _make_sequence_example(self): + example = example_pb2.SequenceExample() + example.context.feature['int_ctx'].int64_list.value.extend([5]) + example.context.feature['float_ctx'].float_list.value.extend([123.6]) + for val in range(0, 10, 2): + feat = feature_pb2.Feature() + feat.int64_list.value.extend([val] * val) + example.feature_lists.feature_list['int_list'].feature.extend([feat]) + for val in range(1, 11, 2): + feat = feature_pb2.Feature() + feat.bytes_list.value.extend([compat.as_bytes(str(val))] * val) + example.feature_lists.feature_list['str_list'].feature.extend([feat]) + + return example + + def _build_feature_columns(self): + col = fc.categorical_column_with_identity( + 'int_ctx', num_buckets=100) + ctx_cols = [ + fc.embedding_column(col, dimension=10), + fc.numeric_column('float_ctx')] + + identity_col = sfc.sequence_categorical_column_with_identity( + 'int_list', num_buckets=10) + bucket_col = sfc.sequence_categorical_column_with_hash_bucket( + 'bytes_list', hash_bucket_size=100) + seq_cols = [ + fc.embedding_column(identity_col, dimension=10), + fc.embedding_column(bucket_col, dimension=20)] + + return ctx_cols, seq_cols + + def test_sequence_example_into_input_layer(self): + examples = [_make_sequence_example().SerializeToString()] * 100 + ctx_cols, seq_cols = self._build_feature_columns() + + def _parse_example(example): + ctx, seq = parsing_ops.parse_single_sequence_example( + example, + context_features=fc.make_parse_example_spec(ctx_cols), + sequence_features=fc.make_parse_example_spec(seq_cols)) + ctx.update(seq) + return ctx + + ds = dataset_ops.Dataset.from_tensor_slices(examples) + ds = ds.map(_parse_example) + ds = ds.batch(20) + + # Test on a single batch + features = ds.make_one_shot_iterator().get_next() + + # Tile the context features across the sequence features + seq_layer, _ = sfc.sequence_input_layer(features, seq_cols) + ctx_layer = fc.input_layer(features, ctx_cols) + input_layer = sfc.concatenate_context_input(ctx_layer, seq_layer) + + rnn_layer = recurrent.RNN(recurrent.SimpleRNNCell(10)) + output = rnn_layer(input_layer) + + with self.cached_session() as sess: + sess.run(variables.global_variables_initializer()) + features_r = sess.run(features) + self.assertAllEqual(features_r['int_list'].dense_shape, [20, 3, 6]) + + output_r = sess.run(output) + self.assertAllEqual(output_r.shape, [20, 10]) + + +class SequenceExampleParsingTest(test.TestCase): + + def test_seq_ex_in_sequence_categorical_column_with_identity(self): + self._test_parsed_sequence_example( + 'int_list', sfc.sequence_categorical_column_with_identity, + 10, [3, 6], [2, 4, 6]) + + def test_seq_ex_in_sequence_categorical_column_with_hash_bucket(self): + self._test_parsed_sequence_example( + 'bytes_list', sfc.sequence_categorical_column_with_hash_bucket, + 10, [3, 4], [compat.as_bytes(x) for x in 'acg']) + + def test_seq_ex_in_sequence_categorical_column_with_vocabulary_list(self): + self._test_parsed_sequence_example( + 'bytes_list', sfc.sequence_categorical_column_with_vocabulary_list, + list(string.ascii_lowercase), [3, 4], + [compat.as_bytes(x) for x in 'acg']) + + def test_seq_ex_in_sequence_categorical_column_with_vocabulary_file(self): + _, fname = tempfile.mkstemp() + with open(fname, 'w') as f: + f.write(string.ascii_lowercase) + self._test_parsed_sequence_example( + 'bytes_list', sfc.sequence_categorical_column_with_vocabulary_file, + fname, [3, 4], [compat.as_bytes(x) for x in 'acg']) + + def _test_parsed_sequence_example( + self, col_name, col_fn, col_arg, shape, values): + """Helper function to check that each FeatureColumn parses correctly. + + Args: + col_name: string, name to give to the feature column. Should match + the name that the column will parse out of the features dict. + col_fn: function used to create the feature column. For example, + sequence_numeric_column. + col_arg: second arg that the target feature column is expecting. + shape: the expected dense_shape of the feature after parsing into + a SparseTensor. + values: the expected values at index [0, 2, 6] of the feature + after parsing into a SparseTensor. + """ + example = _make_sequence_example() + columns = [ + fc.categorical_column_with_identity('int_ctx', num_buckets=100), + fc.numeric_column('float_ctx'), + col_fn(col_name, col_arg) + ] + context, seq_features = parsing_ops.parse_single_sequence_example( + example.SerializeToString(), + context_features=fc.make_parse_example_spec(columns[:2]), + sequence_features=fc.make_parse_example_spec(columns[2:])) + + with self.cached_session() as sess: + ctx_result, seq_result = sess.run([context, seq_features]) + self.assertEqual(list(seq_result[col_name].dense_shape), shape) + self.assertEqual( + list(seq_result[col_name].values[[0, 2, 6]]), values) + self.assertEqual(list(ctx_result['int_ctx'].dense_shape), [1]) + self.assertEqual(ctx_result['int_ctx'].values[0], 5) + self.assertEqual(list(ctx_result['float_ctx'].shape), [1]) + self.assertAlmostEqual(ctx_result['float_ctx'][0], 123.6, places=1) + + +_SEQ_EX_PROTO = """ +context { + feature { + key: "float_ctx" + value { + float_list { + value: 123.6 + } + } + } + feature { + key: "int_ctx" + value { + int64_list { + value: 5 + } + } + } +} +feature_lists { + feature_list { + key: "bytes_list" + value { + feature { + bytes_list { + value: "a" + } + } + feature { + bytes_list { + value: "b" + value: "c" + } + } + feature { + bytes_list { + value: "d" + value: "e" + value: "f" + value: "g" + } + } + } + } + feature_list { + key: "float_list" + value { + feature { + float_list { + value: 1.0 + } + } + feature { + float_list { + value: 3.0 + value: 3.0 + value: 3.0 + } + } + feature { + float_list { + value: 5.0 + value: 5.0 + value: 5.0 + value: 5.0 + value: 5.0 + } + } + } + } + feature_list { + key: "int_list" + value { + feature { + int64_list { + value: 2 + value: 2 + } + } + feature { + int64_list { + value: 4 + value: 4 + value: 4 + value: 4 + } + } + feature { + int64_list { + value: 6 + value: 6 + value: 6 + value: 6 + value: 6 + value: 6 + } + } + } + } +} +""" + + +def _make_sequence_example(): + example = example_pb2.SequenceExample() + return text_format.Parse(_SEQ_EX_PROTO, example) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 45d7b74046..929e83523a 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os +from absl.testing import parameterized import numpy as np from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc @@ -28,28 +29,61 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test from tensorflow.python.training import monitored_session -class SequenceInputLayerTest(test.TestCase): +class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_a': sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)), + 'sparse_input_b': sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [2, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)), + 'expected_input_layer': [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]],], + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'sparse_input_a': sparse_tensor.SparseTensorValue( + # feature 0, ids [[2], [0, 1]] + # feature 1, ids [[0, 0], [1]] + indices=( + (0, 0, 0), (0, 1, 0), (0, 1, 1), + (1, 0, 0), (1, 0, 1), (1, 1, 0)), + values=(2, 0, 1, 0, 0, 1), + dense_shape=(2, 2, 2)), + 'sparse_input_b': sparse_tensor.SparseTensorValue( + # feature 0, ids [[1, 1], [1]] + # feature 1, ids [[2], [0]] + indices=((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + values=(1, 1, 1, 2, 0), + dense_shape=(2, 2, 2)), + 'expected_input_layer': [ + # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -] + [[5., 6., 14., 15., 16.], [2., 3., 14., 15., 16.]], + # feature 1, [a: 0, 0, b: 2, -], [a: 1, -, b: 0, -] + [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]]], + 'expected_sequence_length': [2, 2]}, + ) + def test_embedding_column( + self, sparse_input_a, sparse_input_b, expected_input_layer, + expected_sequence_length): - def test_embedding_column(self): vocabulary_size = 3 - sparse_input_a = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) - sparse_input_b = sparse_tensor.SparseTensorValue( - # example 0, ids [1] - # example 1, ids [2, 0] - indices=((0, 0), (1, 0), (1, 1)), - values=(1, 2, 0), - dense_shape=(2, 2)) - embedding_dimension_a = 2 embedding_values_a = ( (1., 2.), # id 0 @@ -70,14 +104,6 @@ class SequenceInputLayerTest(test.TestCase): return embedding_values return _initializer - expected_input_layer = [ - # example 0, ids_a [2], ids_b [1] - [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]], - # example 1, ids_a [0, 1], ids_b [2, 0] - [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]], - ] - expected_sequence_length = [1, 2] - categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) embedding_column_a = fc.embedding_column( @@ -233,29 +259,53 @@ class SequenceInputLayerTest(test.TestCase): }, feature_columns=shared_embedding_columns) - def test_indicator_column(self): + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_a': sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)), + 'sparse_input_b': sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [1, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 1, 0), + dense_shape=(2, 2)), + 'expected_input_layer': [ + # example 0, ids_a [2], ids_b [1] + [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [1, 0] + [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]], + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'sparse_input_a': sparse_tensor.SparseTensorValue( + # feature 0, ids [[2], [0, 1]] + # feature 1, ids [[0, 0], [1]] + indices=( + (0, 0, 0), (0, 1, 0), (0, 1, 1), + (1, 0, 0), (1, 0, 1), (1, 1, 0)), + values=(2, 0, 1, 0, 0, 1), + dense_shape=(2, 2, 2)), + 'sparse_input_b': sparse_tensor.SparseTensorValue( + # feature 0, ids [[1, 1], [1]] + # feature 1, ids [[1], [0]] + indices=((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + values=(1, 1, 1, 1, 0), + dense_shape=(2, 2, 2)), + 'expected_input_layer': [ + # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -] + [[0., 0., 1., 0., 2.], [1., 1., 0., 0., 1.]], + # feature 1, [a: 0, 0, b: 1, -], [a: 1, -, b: 0, -] + [[2., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]], + 'expected_sequence_length': [2, 2]}, + ) + def test_indicator_column( + self, sparse_input_a, sparse_input_b, expected_input_layer, + expected_sequence_length): vocabulary_size_a = 3 - sparse_input_a = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) vocabulary_size_b = 2 - sparse_input_b = sparse_tensor.SparseTensorValue( - # example 0, ids [1] - # example 1, ids [1, 0] - indices=((0, 0), (1, 0), (1, 1)), - values=(1, 1, 0), - dense_shape=(2, 2)) - - expected_input_layer = [ - # example 0, ids_a [2], ids_b [1] - [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]], - # example 1, ids_a [0, 1], ids_b [1, 0] - [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]], - ] - expected_sequence_length = [1, 2] categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size_a) @@ -298,18 +348,32 @@ class SequenceInputLayerTest(test.TestCase): features={'aaa': sparse_input}, feature_columns=[indicator_column_a]) - def test_numeric_column(self): - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0.], [1]] - # example 1, [[10.]] - indices=((0, 0), (0, 1), (1, 0)), - values=(0., 1., 10.), - dense_shape=(2, 2)) - expected_input_layer = [ - [[0.], [1.]], - [[10.], [0.]], - ] - expected_sequence_length = [2, 1] + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input': sparse_tensor.SparseTensorValue( + # example 0, values [0., 1] + # example 1, [10.] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)), + 'expected_input_layer': [ + [[0.], [1.]], + [[10.], [0.]]], + 'expected_sequence_length': [2, 1]}, + {'testcase_name': '3D', + 'sparse_input': sparse_tensor.SparseTensorValue( + # feature 0, ids [[20, 3], [5]] + # feature 1, ids [[3], [8]] + indices=((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + values=(20, 3, 5., 3., 8.), + dense_shape=(2, 2, 2)), + 'expected_input_layer': [ + [[20.], [3.], [5.], [0.]], + [[3.], [0.], [8.], [0.]]], + 'expected_sequence_length': [2, 2]}, + ) + def test_numeric_column( + self, sparse_input, expected_input_layer, expected_sequence_length): numeric_column = sfc.sequence_numeric_column('aaa') input_layer, sequence_length = sfc.sequence_input_layer( @@ -321,21 +385,38 @@ class SequenceInputLayerTest(test.TestCase): self.assertAllEqual( expected_sequence_length, sequence_length.eval(session=sess)) - def test_numeric_column_multi_dim(self): + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input': sparse_tensor.SparseTensorValue( + # example 0, values [0., 1., 2., 3., 4., 5., 6., 7.] + # example 1, [10., 11., 12., 13.] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), + (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 8)), + 'expected_input_layer': [ + # The output of numeric_column._get_dense_tensor should be flattened. + [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]]], + 'expected_sequence_length': [2, 1]}, + {'testcase_name': '3D', + 'sparse_input': sparse_tensor.SparseTensorValue( + # example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]] + # example 1, [[10., 11., 12., 13.], []] + indices=((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), + (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 3), + (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 2, 4)), + 'expected_input_layer': [ + # The output of numeric_column._get_dense_tensor should be flattened. + [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]]], + 'expected_sequence_length': [2, 1]}, + ) + def test_numeric_column_multi_dim( + self, sparse_input, expected_input_layer, expected_sequence_length): """Tests sequence_input_layer for multi-dimensional numeric_column.""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] - # example 1, [[[10., 11.], [12., 13.]]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), - (1, 0), (1, 1), (1, 2), (1, 3)), - values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), - dense_shape=(2, 8)) - # The output of numeric_column._get_dense_tensor should be flattened. - expected_input_layer = [ - [[0., 1., 2., 3.], [4., 5., 6., 7.]], - [[10., 11., 12., 13.], [0., 0., 0., 0.]], - ] - expected_sequence_length = [2, 1] numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) input_layer, sequence_length = sfc.sequence_input_layer( @@ -377,6 +458,134 @@ class SequenceInputLayerTest(test.TestCase): r'\[y \(sequence_input_layer/bbb/sequence_length:0\) = \] \[1 1\]'): sess.run(sequence_length) + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input': sparse_tensor.SparseTensorValue( + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), + (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 8)), + 'expected_shape': [2, 2, 4]}, + {'testcase_name': '3D', + 'sparse_input': sparse_tensor.SparseTensorValue( + # example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]] + # example 1, [[10., 11., 12., 13.], []] + indices=((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), + (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 2), + (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 2, 4)), + 'expected_shape': [2, 2, 4]}, + ) + def test_static_shape_from_tensors_numeric( + self, sparse_input, expected_shape): + """Tests that we return a known static shape when we have one.""" + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + input_layer, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + shape = input_layer.get_shape() + self.assertEqual(shape, expected_shape) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input': sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)), + 'expected_shape': [4, 2, 3]}, + {'testcase_name': '3D', + 'sparse_input': sparse_tensor.SparseTensorValue( + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + # example 2, ids [] + # example 3, ids [[1], [0, 2]] + indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), + (3, 0, 0), (3, 1, 0), (3, 1, 1)), + values=(2, 0, 1, 2, 1, 0, 2), + dense_shape=(4, 2, 2)), + 'expected_shape': [4, 2, 3]} + ) + def test_static_shape_from_tensors_indicator( + self, sparse_input, expected_shape): + """Tests that we return a known static shape when we have one.""" + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=3) + indicator_column = fc.indicator_column(categorical_column) + + input_layer, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, feature_columns=[indicator_column]) + shape = input_layer.get_shape() + self.assertEqual(shape, expected_shape) + + +class ConcatenateContextInputTest(test.TestCase, parameterized.TestCase): + """Tests the utility fn concatenate_context_input.""" + + def test_concatenate_context_input(self): + seq_input = ops.convert_to_tensor(np.arange(12).reshape(2, 3, 2)) + context_input = ops.convert_to_tensor(np.arange(10).reshape(2, 5)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + input_layer = sfc.concatenate_context_input(context_input, seq_input) + + expected = np.array([ + [[0, 1, 0, 1, 2, 3, 4], [2, 3, 0, 1, 2, 3, 4], [4, 5, 0, 1, 2, 3, 4]], + [[6, 7, 5, 6, 7, 8, 9], [8, 9, 5, 6, 7, 8, 9], [10, 11, 5, 6, 7, 8, 9]] + ], dtype=np.float32) + with monitored_session.MonitoredSession() as sess: + output = sess.run(input_layer) + self.assertAllEqual(expected, output) + + @parameterized.named_parameters( + {'testcase_name': 'rank_lt_3', + 'seq_input': ops.convert_to_tensor(np.arange(100).reshape(10, 10))}, + {'testcase_name': 'rank_gt_3', + 'seq_input': ops.convert_to_tensor(np.arange(100).reshape(5, 5, 2, 2))} + ) + def test_sequence_input_throws_error(self, seq_input): + context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + with self.assertRaisesRegexp(ValueError, 'sequence_input must have rank 3'): + sfc.concatenate_context_input(context_input, seq_input) + + @parameterized.named_parameters( + {'testcase_name': 'rank_lt_2', + 'context_input': ops.convert_to_tensor(np.arange(100))}, + {'testcase_name': 'rank_gt_2', + 'context_input': ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4))} + ) + def test_context_input_throws_error(self, context_input): + seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + with self.assertRaisesRegexp(ValueError, 'context_input must have rank 2'): + sfc.concatenate_context_input(context_input, seq_input) + + def test_integer_seq_input_throws_error(self): + seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4)) + context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10)) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + with self.assertRaisesRegexp( + TypeError, 'sequence_input must have dtype float32'): + sfc.concatenate_context_input(context_input, seq_input) + + def test_integer_context_input_throws_error(self): + seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4)) + context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + with self.assertRaisesRegexp( + TypeError, 'context_input must have dtype float32'): + sfc.concatenate_context_input(context_input, seq_input) + class InputLayerTest(test.TestCase): """Tests input_layer with sequence feature columns.""" @@ -443,75 +652,79 @@ def _assert_sparse_tensor_indices_shape(test_case, expected, actual): test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) -class SequenceCategoricalColumnWithIdentityTest(test.TestCase): - - def test_get_sparse_tensors(self): - column = sfc.sequence_categorical_column_with_identity( - 'aaa', num_buckets=3) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=(1, 2, 0), - dense_shape=(2, 2)) - expected_sparse_ids = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - values=np.array((1, 2, 0), dtype=np.int64), - dense_shape=(2, 2, 1)) +class SequenceCategoricalColumnWithIdentityTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs': sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)), + 'expected': sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=np.array((1, 2, 0), dtype=np.int64), + dense_shape=(2, 2, 1))}, + {'testcase_name': '3D', + 'inputs': sparse_tensor.SparseTensorValue( + indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)), + values=(6, 7, 8), + dense_shape=(2, 2, 2)), + 'expected': sparse_tensor.SparseTensorValue( + indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)), + values=(6, 7, 8), + dense_shape=(2, 2, 2))} + ) + def test_get_sparse_tensors(self, inputs, expected): + column = sfc.sequence_categorical_column_with_identity('aaa', num_buckets=9) id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: _assert_sparse_tensor_value( - self, - expected_sparse_ids, - id_weight_pair.id_tensor.eval(session=sess)) - - def test_get_sparse_tensors_inputs3d(self): - """Tests _get_sparse_tensors when the input is already 3D Tensor.""" - column = sfc.sequence_categorical_column_with_identity( - 'aaa', num_buckets=3) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - values=(1, 2, 0), - dense_shape=(2, 2, 1)) - - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - r'Column aaa expected ID tensor of rank 2\.\s*' - r'id_tensor shape:\s*\[2 2 1\]'): - id_weight_pair = column._get_sparse_tensors( - _LazyBuilder({'aaa': inputs})) - with monitored_session.MonitoredSession() as sess: - id_weight_pair.id_tensor.eval(session=sess) - - -class SequenceCategoricalColumnWithHashBucketTest(test.TestCase): - - def test_get_sparse_tensors(self): + self, expected, id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceCategoricalColumnWithHashBucketTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs': sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)), + 'expected': sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + # Ignored to avoid hash dependence in test. + values=np.array((0, 0, 0), dtype=np.int64), + dense_shape=(2, 2, 1))}, + {'testcase_name': '3D', + 'inputs': sparse_tensor.SparseTensorValue( + indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2, 2)), + 'expected': sparse_tensor.SparseTensorValue( + indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)), + # Ignored to avoid hash dependence in test. + values=np.array((0, 0, 0), dtype=np.int64), + dense_shape=(2, 2, 2))} + ) + def test_get_sparse_tensors(self, inputs, expected): column = sfc.sequence_categorical_column_with_hash_bucket( 'aaa', hash_bucket_size=10) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=('omar', 'stringer', 'marlo'), - dense_shape=(2, 2)) - - expected_sparse_ids = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - # Ignored to avoid hash dependence in test. - values=np.array((0, 0, 0), dtype=np.int64), - dense_shape=(2, 2, 1)) id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: _assert_sparse_tensor_indices_shape( - self, - expected_sparse_ids, - id_weight_pair.id_tensor.eval(session=sess)) + self, expected, id_weight_pair.id_tensor.eval(session=sess)) -class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase): +class SequenceCategoricalColumnWithVocabularyFileTest( + test.TestCase, parameterized.TestCase): def _write_vocab(self, vocab_strings, file_name): vocab_file = os.path.join(self.get_temp_dir(), file_name) @@ -527,68 +740,120 @@ class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase): 'wire_vocabulary.txt') self._wire_vocabulary_size = 3 - def test_get_sparse_tensors(self): + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs': sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)), + 'expected': sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=(2, 2, 1))}, + {'testcase_name': '3D', + 'inputs': sparse_tensor.SparseTensorValue( + indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)), + values=('omar', 'skywalker', 'marlo'), + dense_shape=(2, 2, 2)), + 'expected': sparse_tensor.SparseTensorValue( + indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)), + values=np.array((0, -1, 2), dtype=np.int64), + dense_shape=(2, 2, 2))} + ) + def test_get_sparse_tensors(self, inputs, expected): column = sfc.sequence_categorical_column_with_vocabulary_file( key='aaa', vocabulary_file=self._wire_vocabulary_file_name, vocabulary_size=self._wire_vocabulary_size) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=('marlo', 'skywalker', 'omar'), - dense_shape=(2, 2)) - expected_sparse_ids = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - values=np.array((2, -1, 0), dtype=np.int64), - dense_shape=(2, 2, 1)) id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: _assert_sparse_tensor_value( - self, - expected_sparse_ids, - id_weight_pair.id_tensor.eval(session=sess)) - - -class SequenceCategoricalColumnWithVocabularyListTest(test.TestCase): - - def test_get_sparse_tensors(self): + self, expected, id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceCategoricalColumnWithVocabularyListTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs': sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)), + 'expected': sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=(2, 2, 1))}, + {'testcase_name': '3D', + 'inputs': sparse_tensor.SparseTensorValue( + indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)), + values=('omar', 'skywalker', 'marlo'), + dense_shape=(2, 2, 2)), + 'expected': sparse_tensor.SparseTensorValue( + indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)), + values=np.array((0, -1, 2), dtype=np.int64), + dense_shape=(2, 2, 2))} + ) + def test_get_sparse_tensors(self, inputs, expected): column = sfc.sequence_categorical_column_with_vocabulary_list( key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=('marlo', 'skywalker', 'omar'), - dense_shape=(2, 2)) - expected_sparse_ids = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - values=np.array((2, -1, 0), dtype=np.int64), - dense_shape=(2, 2, 1)) id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: _assert_sparse_tensor_value( - self, - expected_sparse_ids, - id_weight_pair.id_tensor.eval(session=sess)) - - -class SequenceEmbeddingColumnTest(test.TestCase): - - def test_get_sequence_dense_tensor(self): + self, expected, id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceEmbeddingColumnTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs': sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)), + 'expected': [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]]]}, + {'testcase_name': '3D', + 'inputs': sparse_tensor.SparseTensorValue( + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + # example 2, ids [] + # example 3, ids [[1], [0, 2]] + indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), + (3, 0, 0), (3, 1, 0), (3, 1, 1)), + values=(2, 0, 1, 2, 1, 0, 2), + dense_shape=(4, 2, 2)), + 'expected': [ + # example 0, ids [[2]] + [[7., 11.], [0., 0.]], + # example 1, ids [[0, 1], [2]] + [[2, 3.5], [7., 11.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [[1], [0, 2]] + [[3., 5.], [4., 6.5]]]} + ) + def test_get_sequence_dense_tensor(self, inputs, expected): vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - # example 2, ids [] - # example 3, ids [1] - indices=((0, 0), (1, 0), (1, 1), (3, 0)), - values=(2, 0, 1, 1), - dense_shape=(4, 2)) - embedding_dimension = 2 embedding_values = ( (1., 2.), # id 0 @@ -601,17 +866,6 @@ class SequenceEmbeddingColumnTest(test.TestCase): self.assertIsNone(partition_info) return embedding_values - expected_lookups = [ - # example 0, ids [2] - [[7., 11.], [0., 0.]], - # example 1, ids [0, 1] - [[1., 2.], [3., 5.]], - # example 2, ids [] - [[0., 0.], [0., 0.]], - # example 3, ids [1] - [[3., 5.], [0., 0.]], - ] - categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) embedding_column = fc.embedding_column( @@ -619,24 +873,35 @@ class SequenceEmbeddingColumnTest(test.TestCase): initializer=_initializer) embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertItemsEqual( ('embedding_weights:0',), tuple([v.name for v in global_vars])) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) - self.assertAllEqual(expected_lookups, embedding_lookup.eval(session=sess)) - - def test_sequence_length(self): + self.assertAllEqual(expected, embedding_lookup.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs': sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)), + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'inputs': sparse_tensor.SparseTensorValue( + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + values=(2, 0, 1, 2), + dense_shape=(2, 2, 2)), + 'expected_sequence_length': [1, 2]} + ) + def test_sequence_length(self, inputs, expected_sequence_length): vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) - expected_sequence_length = [1, 2] categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) @@ -644,7 +909,7 @@ class SequenceEmbeddingColumnTest(test.TestCase): categorical_column, dimension=2) _, sequence_length = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) @@ -855,56 +1120,87 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): expected_sequence_length_b, sequence_length_b.eval(session=sess)) -class SequenceIndicatorColumnTest(test.TestCase): - - def test_get_sequence_dense_tensor(self): +class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs': sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)), + 'expected': [ + # example 0, ids [2] + [[0., 0., 1.], [0., 0., 0.]], + # example 1, ids [0, 1] + [[1., 0., 0.], [0., 1., 0.]], + # example 2, ids [] + [[0., 0., 0.], [0., 0., 0.]], + # example 3, ids [1] + [[0., 1., 0.], [0., 0., 0.]]]}, + {'testcase_name': '3D', + 'inputs': sparse_tensor.SparseTensorValue( + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + # example 2, ids [] + # example 3, ids [[1], [2, 2]] + indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), + (3, 0, 0), (3, 1, 0), (3, 1, 1)), + values=(2, 0, 1, 2, 1, 2, 2), + dense_shape=(4, 2, 2)), + 'expected': [ + # example 0, ids [[2]] + [[0., 0., 1.], [0., 0., 0.]], + # example 1, ids [[0, 1], [2]] + [[1., 1., 0.], [0., 0., 1.]], + # example 2, ids [] + [[0., 0., 0.], [0., 0., 0.]], + # example 3, ids [[1], [2, 2]] + [[0., 1., 0.], [0., 0., 2.]]]} + ) + def test_get_sequence_dense_tensor(self, inputs, expected): vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - # example 2, ids [] - # example 3, ids [1] - indices=((0, 0), (1, 0), (1, 1), (3, 0)), - values=(2, 0, 1, 1), - dense_shape=(4, 2)) - - expected_lookups = [ - # example 0, ids [2] - [[0., 0., 1.], [0., 0., 0.]], - # example 1, ids [0, 1] - [[1., 0., 0.], [0., 1., 0.]], - # example 2, ids [] - [[0., 0., 0.], [0., 0., 0.]], - # example 3, ids [1] - [[0., 1., 0.], [0., 0., 0.]], - ] categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) indicator_column = fc.indicator_column(categorical_column) indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) with monitored_session.MonitoredSession() as sess: - self.assertAllEqual(expected_lookups, indicator_tensor.eval(session=sess)) - - def test_sequence_length(self): + self.assertAllEqual(expected, indicator_tensor.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs': sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)), + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'inputs': sparse_tensor.SparseTensorValue( + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + values=(2, 0, 1, 2), + dense_shape=(2, 2, 2)), + 'expected_sequence_length': [1, 2]} + ) + def test_sequence_length(self, inputs, expected_sequence_length): vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) - expected_sequence_length = [1, 2] categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) indicator_column = fc.indicator_column(categorical_column) _, sequence_length = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) @@ -938,7 +1234,7 @@ class SequenceIndicatorColumnTest(test.TestCase): expected_sequence_length, sequence_length.eval(session=sess)) -class SequenceNumericColumnTest(test.TestCase): +class SequenceNumericColumnTest(test.TestCase, parameterized.TestCase): def test_defaults(self): a = sfc.sequence_numeric_column('aaa') @@ -971,25 +1267,36 @@ class SequenceNumericColumnTest(test.TestCase): with self.assertRaisesRegexp(TypeError, 'must be a callable'): sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable') - def test_get_sequence_dense_tensor(self): - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0.], [1]] - # example 1, [[10.]] - indices=((0, 0), (0, 1), (1, 0)), - values=(0., 1., 10.), - dense_shape=(2, 2)) - expected_dense_tensor = [ - [[0.], [1.]], - [[10.], [0.]], - ] + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs': sparse_tensor.SparseTensorValue( + # example 0, values [0., 1] + # example 1, [10.] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)), + 'expected': [ + [[0.], [1.]], + [[10.], [0.]]]}, + {'testcase_name': '3D', + 'inputs': sparse_tensor.SparseTensorValue( + # feature 0, ids [[20, 3], [5]] + # feature 1, ids [[3], [8]] + indices=((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + values=(20, 3, 5., 3., 8.), + dense_shape=(2, 2, 2)), + 'expected': [ + [[20.], [3.], [5.], [0.]], + [[3.], [0.], [8.], [0.]]]}, + ) + def test_get_sequence_dense_tensor(self, inputs, expected): numeric_column = sfc.sequence_numeric_column('aaa') dense_tensor, _ = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_dense_tensor, dense_tensor.eval(session=sess)) + self.assertAllEqual(expected, dense_tensor.eval(session=sess)) def test_get_sequence_dense_tensor_with_normalizer_fn(self): @@ -1026,41 +1333,34 @@ class SequenceNumericColumnTest(test.TestCase): self.assertAllEqual( expected_dense_tensor, dense_tensor.eval(session=sess)) - def test_get_sequence_dense_tensor_with_shape(self): - """Tests get_sequence_dense_tensor with shape !=(1,).""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0., 1., 2.], [3., 4., 5.]] - # example 1, [[10., 11., 12.]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), - (1, 0), (1, 1), (1, 2)), - values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), - dense_shape=(2, 6)) - expected_dense_tensor = [ - [[0., 1., 2.], [3., 4., 5.]], - [[10., 11., 12.], [0., 0., 0.]], - ] - numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) - - dense_tensor, _ = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_dense_tensor, dense_tensor.eval(session=sess)) - - def test_get_dense_tensor_multi_dim(self): + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input': sparse_tensor.SparseTensorValue( + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), + (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 8)), + 'expected_dense_tensor': [ + [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]], + [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]]]}, + {'testcase_name': '3D', + 'sparse_input': sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (0, 0, 2), (0, 0, 4), (0, 0, 6), + (0, 1, 0), (0, 1, 2), (0, 1, 4), (0, 1, 6), + (1, 0, 0), (1, 0, 2), (1, 0, 4), (1, 0, 6)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 2, 8)), + 'expected_dense_tensor': [ + [[[0., 0.], [1., 0.]], [[2., 0.], [3., 0.]], + [[4., 0.], [5., 0.]], [[6., 0.], [7., 0.]]], + [[[10., 0.], [11., 0.]], [[12., 0.], [13., 0.]], + [[0., 0.], [0., 0.]], [[0., 0.], [0., 0.]]]]}, + ) + def test_get_dense_tensor_multi_dim( + self, sparse_input, expected_dense_tensor): """Tests get_sequence_dense_tensor for multi-dim numeric_column.""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] - # example 1, [[[10., 11.], [12., 13.]]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), - (1, 0), (1, 1), (1, 2), (1, 3)), - values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), - dense_shape=(2, 8)) - expected_dense_tensor = [ - [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]], - [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]], - ] numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) dense_tensor, _ = numeric_column._get_sequence_dense_tensor( @@ -1070,43 +1370,55 @@ class SequenceNumericColumnTest(test.TestCase): self.assertAllEqual( expected_dense_tensor, dense_tensor.eval(session=sess)) - def test_sequence_length(self): - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0., 1., 2.], [3., 4., 5.]] - # example 1, [[10., 11., 12.]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), - (1, 0), (1, 1), (1, 2)), - values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), - dense_shape=(2, 6)) - expected_sequence_length = [2, 1] - numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs': sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2., 0., 1.), + dense_shape=(2, 2)), + 'expected_sequence_length': [1, 2], + 'shape': (1,)}, + {'testcase_name': '3D', + 'inputs': sparse_tensor.SparseTensorValue( + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + values=(2., 0., 1., 2.), + dense_shape=(2, 2, 2)), + 'expected_sequence_length': [1, 2], + 'shape': (1,)}, + {'testcase_name': '2D_with_shape', + 'inputs': sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2., 0., 1.), + dense_shape=(2, 2)), + 'expected_sequence_length': [1, 1], + 'shape': (2,)}, + {'testcase_name': '3D_with_shape', + 'inputs': sparse_tensor.SparseTensorValue( + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + values=(2., 0., 1., 2.), + dense_shape=(2, 2, 2)), + 'expected_sequence_length': [1, 2], + 'shape': (2,)}, + ) + def test_sequence_length(self, inputs, expected_sequence_length, shape): + numeric_column = sfc.sequence_numeric_column('aaa', shape=shape) _, sequence_length = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) self.assertAllEqual(expected_sequence_length, sequence_length) self.assertEqual(np.int64, sequence_length.dtype) - def test_sequence_length_with_shape(self): - """Tests _sequence_length with shape !=(1,).""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0.], [1]] - # example 1, [[10.]] - indices=((0, 0), (0, 1), (1, 0)), - values=(0., 1., 10.), - dense_shape=(2, 2)) - expected_sequence_length = [2, 1] - numeric_column = sfc.sequence_numeric_column('aaa') - - _, sequence_length = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) - def test_sequence_length_with_empty_rows(self): """Tests _sequence_length when some examples do not have ids.""" sparse_input = sparse_tensor.SparseTensorValue( -- cgit v1.2.3