diff options
author | 2017-05-09 12:32:41 -0700 | |
---|---|---|
committer | 2017-05-10 16:51:41 -0700 | |
commit | 1bad658d63e5fb21d321cd680e7451c96e032f7e (patch) | |
tree | 715a70bafaa94e6f6d977e2129fc8222ecfbd857 | |
parent | 557ab679e730d250d665695d1556d5c2b25f7f07 (diff) |
Moved transform_features.
removed pylint lines from test.
PiperOrigin-RevId: 155538004
-rw-r--r-- | tensorflow/python/feature_column/feature_column.py | 44 | ||||
-rw-r--r-- | tensorflow/python/feature_column/feature_column_test.py | 171 |
2 files changed, 132 insertions, 83 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 70c38574fb..6efab10efb 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -313,6 +313,50 @@ def make_linear_model(features, return predictions +def _transform_features(features, feature_columns): + """Returns transformed features based on features columns passed in. + + Please note that most probably you would not need to use this function. Please + check `make_input_layer` and `make_linear_model` to see whether they will + satisfy your use case or not. + + Example: + + ```python + # Define features and transformations + crosses_a_x_b = crossed_column( + columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000) + price_buckets = bucketized_column( + source_column=numeric_column("price"), boundaries=[...]) + + columns = [crosses_a_x_b, price_buckets] + features = tf.parse_example(..., features=parse_example_spec(columns)) + transformed = transform_features(features=features, feature_columns=columns) + + assertCountEqual(columns, transformed.keys()) + ``` + + Args: + features: A mapping from key to tensors. `FeatureColumn`s look up via these + keys. For example `numeric_column('price') will look at 'price' key in + this dict. Values can be a `SparseTensor` or a `Tensor` depends on + corresponding `FeatureColumn`. + feature_columns: An iterable containing all the `FeatureColumn`s. + + Returns: + A `dict` mapping FeatureColumn to `Tensor` and `SparseTensor` values. + """ + _check_feature_columns(feature_columns) + outputs = {} + with ops.name_scope( + None, default_name='transform_features', values=features.values()): + builder = _LazyBuilder(features) + for column in sorted(feature_columns, key=lambda x: x.name): + with ops.name_scope(None, default_name=column.name): + outputs[column] = builder.get(column) + return outputs + + def numeric_column(key, shape=(1,), default_value=None, diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index 5201811831..bd9a7c8df8 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -274,12 +274,9 @@ class NumericColumnTest(test.TestCase): return input_tensor + 2. price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two) - builder = fc._LazyBuilder({ - 'price': [[1., 2.], [5., 6.]] - }) - output = builder.get(price) + output = fc._transform_features({'price': [[1., 2.], [5., 6.]]}, [price]) with self.test_session(): - self.assertAllEqual([[3., 4.], [7., 8.]], output.eval()) + self.assertAllEqual([[3., 4.], [7., 8.]], output[price].eval()) def test_get_dense_tensor(self): @@ -403,12 +400,12 @@ class BucketizedColumnTest(test.TestCase): price = fc.numeric_column('price', shape=[2]) bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6]) with ops.Graph().as_default(): - builder = fc._LazyBuilder({ + transformed_tensor = fc._transform_features({ 'price': [[-1., 1.], [5., 6.]] - }) - transformed_tensor = builder.get(bucketized_price) + }, [bucketized_price]) with _initialized_session(): - self.assertAllEqual([[0, 1], [3, 4]], transformed_tensor.eval()) + self.assertAllEqual([[0, 1], [3, 4]], + transformed_tensor[bucketized_price].eval()) def test_get_dense_tensor_one_input_value(self): """Tests _get_dense_tensor() for input with shape=[1].""" @@ -584,9 +581,7 @@ class HashedCategoricalColumnTest(test.TestCase): for column in (original, copy.deepcopy(original)): self.assertEqual('aaa', column.name) self.assertEqual(10, column.hash_bucket_size) - # pylint: disable=protected-access self.assertEqual(10, column._num_buckets) - # pylint: enable=protected-access self.assertEqual(dtypes.string, column.dtype) def test_parse_config(self): @@ -607,8 +602,8 @@ class HashedCategoricalColumnTest(test.TestCase): values=['omar', 'stringer', 'marlo'], indices=[[0, 0], [1, 0], [1, 1]], dense_shape=[2, 2]) - builder = fc._LazyBuilder({'wire': wire_tensor}) - output = builder.get(hashed_sparse) + outputs = fc._transform_features({'wire': wire_tensor}, [hashed_sparse]) + output = outputs[hashed_sparse] # Check exact hashed output. If hashing changes this test will break. expected_values = [6, 4, 1] with self.test_session(): @@ -1225,23 +1220,19 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): column = fc.categorical_column_with_vocabulary_file( key='aaa', vocabulary_file='path_to_file', vocabulary_size=3) self.assertEqual('aaa', column.name) - # pylint: disable=protected-access self.assertEqual(3, column._num_buckets) self.assertEqual({ 'aaa': parsing_ops.VarLenFeature(dtypes.string) }, column._parse_example_config) - # pylint: enable=protected-access def test_all_constructor_args(self): column = fc.categorical_column_with_vocabulary_file( key='aaa', vocabulary_file='path_to_file', vocabulary_size=3, num_oov_buckets=4, dtype=dtypes.int32) - # pylint: disable=protected-access self.assertEqual(7, column._num_buckets) self.assertEqual({ 'aaa': parsing_ops.VarLenFeature(dtypes.int32) }, column._parse_example_config) - # pylint: enable=protected-access def test_deep_copy(self): """Tests deepcopy of categorical_column_with_hash_bucket.""" @@ -1250,12 +1241,10 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): num_oov_buckets=4, dtype=dtypes.int32) for column in (original, copy.deepcopy(original)): self.assertEqual('aaa', column.name) - # pylint: disable=protected-access self.assertEqual(7, column._num_buckets) self.assertEqual({ 'aaa': parsing_ops.VarLenFeature(dtypes.int32) }, column._parse_example_config) - # pylint: enable=protected-access def test_vocabulary_file_none(self): with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'): @@ -1274,9 +1263,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): indices=((0, 0), (1, 0), (1, 1)), values=('marlo', 'skywalker', 'omar'), dense_shape=(2, 2)) - # pylint: disable=protected-access column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'): with self.test_session(): lookup_ops.tables_initializer().run() @@ -1304,9 +1291,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): indices=((0, 0), (1, 0), (1, 1)), values=('marlo', 'skywalker', 'omar'), dense_shape=(2, 2)) - # pylint: disable=protected-access column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'): with self.test_session(): lookup_ops.tables_initializer().run() @@ -1344,9 +1329,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): values=(12, 24, 36), dense_shape=(2, 2)) with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): - # pylint: disable=protected-access column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access def test_invalid_input_dtype_string(self): column = fc.categorical_column_with_vocabulary_file( @@ -1359,9 +1342,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): values=('omar', 'stringer', 'marlo'), dense_shape=(2, 2)) with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): - # pylint: disable=protected-access column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access def test_get_sparse_tensors(self): column = fc.categorical_column_with_vocabulary_file( @@ -1372,10 +1353,8 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): indices=((0, 0), (1, 0), (1, 1)), values=('marlo', 'skywalker', 'omar'), dense_shape=(2, 2)) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors( fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1386,16 +1365,33 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): dense_shape=inputs.dense_shape), id_weight_pair.id_tensor.eval()) + def test_transform_feature(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + id_tensor = fc._transform_features({'aaa': inputs}, [column])[column] + with _initialized_session(): + _assert_sparse_tensor_value(self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array( + (2, -1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_tensor.eval()) + def test_get_sparse_tensors_dense_input(self): column = fc.categorical_column_with_vocabulary_file( key='aaa', vocabulary_file=self._wire_vocabulary_file_name, vocabulary_size=self._wire_vocabulary_size) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({ 'aaa': (('marlo', ''), ('skywalker', 'omar')) })) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1416,10 +1412,8 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): indices=((0, 0), (1, 0), (1, 1)), values=('marlo', 'skywalker', 'omar'), dense_shape=(2, 2)) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors( fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1440,10 +1434,8 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): indices=((0, 0), (1, 0), (1, 1), (1, 2)), values=('marlo', 'skywalker', 'omar', 'heisenberg'), dense_shape=(2, 3)) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors( fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1466,10 +1458,8 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): indices=((0, 0), (1, 0), (1, 1)), values=('marlo', 'skywalker', 'omar'), dense_shape=(2, 2)) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors( fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1490,10 +1480,8 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): indices=((0, 0), (1, 0), (1, 1), (2, 2)), values=(11, 100, 30, 22), dense_shape=(3, 3)) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors( fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1512,11 +1500,9 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): vocabulary_size=self._warriors_vocabulary_size, dtype=dtypes.int32, default_value=default_value) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({ 'aaa': ((11, -1, -1), (100, 30, -1), (-1, -1, 22)) })) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1538,10 +1524,8 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): indices=((0, 0), (1, 0), (1, 1), (2, 2)), values=(11, 100, 30, 22), dense_shape=(3, 3)) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors( fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1584,34 +1568,28 @@ class VocabularyListCategoricalColumnTest(test.TestCase): column = fc.categorical_column_with_vocabulary_list( key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) self.assertEqual('aaa', column.name) - # pylint: disable=protected-access self.assertEqual(3, column._num_buckets) self.assertEqual({ 'aaa': parsing_ops.VarLenFeature(dtypes.string) }, column._parse_example_config) - # pylint: enable=protected-access def test_defaults_int(self): column = fc.categorical_column_with_vocabulary_list( key='aaa', vocabulary_list=(12, 24, 36)) self.assertEqual('aaa', column.name) - # pylint: disable=protected-access self.assertEqual(3, column._num_buckets) self.assertEqual({ 'aaa': parsing_ops.VarLenFeature(dtypes.int64) }, column._parse_example_config) - # pylint: enable=protected-access def test_all_constructor_args(self): column = fc.categorical_column_with_vocabulary_list( key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32, default_value=-99) - # pylint: disable=protected-access self.assertEqual(3, column._num_buckets) self.assertEqual({ 'aaa': parsing_ops.VarLenFeature(dtypes.int32) }, column._parse_example_config) - # pylint: enable=protected-access def test_deep_copy(self): """Tests deepcopy of categorical_column_with_hash_bucket.""" @@ -1619,12 +1597,10 @@ class VocabularyListCategoricalColumnTest(test.TestCase): key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32) for column in (original, copy.deepcopy(original)): self.assertEqual('aaa', column.name) - # pylint: disable=protected-access self.assertEqual(3, column._num_buckets) self.assertEqual({ 'aaa': parsing_ops.VarLenFeature(dtypes.int32) }, column._parse_example_config) - # pylint: enable=protected-access def test_invalid_dtype(self): with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'): @@ -1677,9 +1653,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase): values=(12, 24, 36), dense_shape=(2, 2)) with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): - # pylint: disable=protected-access column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access def test_invalid_input_dtype_string(self): column = fc.categorical_column_with_vocabulary_list( @@ -1690,9 +1664,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase): values=('omar', 'stringer', 'marlo'), dense_shape=(2, 2)) with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): - # pylint: disable=protected-access column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access def test_get_sparse_tensors(self): column = fc.categorical_column_with_vocabulary_list( @@ -1702,10 +1674,8 @@ class VocabularyListCategoricalColumnTest(test.TestCase): indices=((0, 0), (1, 0), (1, 1)), values=('marlo', 'skywalker', 'omar'), dense_shape=(2, 2)) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors( fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1720,11 +1690,9 @@ class VocabularyListCategoricalColumnTest(test.TestCase): column = fc.categorical_column_with_vocabulary_list( key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({ 'aaa': (('marlo', ''), ('skywalker', 'omar')) })) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1744,10 +1712,8 @@ class VocabularyListCategoricalColumnTest(test.TestCase): indices=((0, 0), (1, 0), (1, 1)), values=('marlo', 'skywalker', 'omar'), dense_shape=(2, 2)) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors( fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1767,10 +1733,8 @@ class VocabularyListCategoricalColumnTest(test.TestCase): indices=((0, 0), (1, 0), (1, 1), (2, 2)), values=np.array((11, 100, 30, 22), dtype=np.int32), dense_shape=(3, 3)) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors( fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1788,13 +1752,11 @@ class VocabularyListCategoricalColumnTest(test.TestCase): vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32), dtype=dtypes.int32, default_value=default_value) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({ 'aaa': np.array( ((11, -1, -1), (100, 30, -1), (-1, -1, 22)), dtype=np.int32) })) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1834,24 +1796,20 @@ class IdentityCategoricalColumnTest(test.TestCase): def test_constructor(self): column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) self.assertEqual('aaa', column.name) - # pylint: disable=protected-access self.assertEqual(3, column._num_buckets) self.assertEqual({ 'aaa': parsing_ops.VarLenFeature(dtypes.int64) }, column._parse_example_config) - # pylint: enable=protected-access def test_deep_copy(self): """Tests deepcopy of categorical_column_with_hash_bucket.""" original = fc.categorical_column_with_identity(key='aaa', num_buckets=3) for column in (original, copy.deepcopy(original)): self.assertEqual('aaa', column.name) - # pylint: disable=protected-access self.assertEqual(3, column._num_buckets) self.assertEqual({ 'aaa': parsing_ops.VarLenFeature(dtypes.int64) }, column._parse_example_config) - # pylint: enable=protected-access def test_invalid_num_buckets_zero(self): with self.assertRaisesRegexp(ValueError, 'num_buckets 0 < 1'): @@ -1878,9 +1836,7 @@ class IdentityCategoricalColumnTest(test.TestCase): values=('omar', 'stringer', 'marlo'), dense_shape=(2, 2)) with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'): - # pylint: disable=protected-access column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access def test_get_sparse_tensors(self): column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) @@ -1888,10 +1844,8 @@ class IdentityCategoricalColumnTest(test.TestCase): indices=((0, 0), (1, 0), (1, 1)), values=(0, 1, 0), dense_shape=(2, 2)) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors( fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1904,11 +1858,9 @@ class IdentityCategoricalColumnTest(test.TestCase): def test_get_sparse_tensors_dense_input(self): column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({ 'aaa': ((0, -1), (1, 0)) })) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1925,10 +1877,8 @@ class IdentityCategoricalColumnTest(test.TestCase): indices=((0, 0), (1, 0), (1, 1)), values=(1, -1, 0), dense_shape=(2, 2)) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors( fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): with self.assertRaisesRegexp( @@ -1941,10 +1891,8 @@ class IdentityCategoricalColumnTest(test.TestCase): indices=((0, 0), (1, 0), (1, 1)), values=(1, 99, 0), dense_shape=(2, 2)) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors( fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): with self.assertRaisesRegexp( @@ -1958,10 +1906,8 @@ class IdentityCategoricalColumnTest(test.TestCase): indices=((0, 0), (1, 0), (1, 1)), values=(1, -1, 99), dense_shape=(2, 2)) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors( fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -1982,10 +1928,8 @@ class IdentityCategoricalColumnTest(test.TestCase): indices=input_indices, values=input_values, dense_shape=input_shape) - # pylint: disable=protected-access id_weight_pair = column._get_sparse_tensors( fc._LazyBuilder({'aaa': inputs})) - # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): _assert_sparse_tensor_value( @@ -2022,5 +1966,66 @@ class IdentityCategoricalColumnTest(test.TestCase): self.assertAllClose(((1.,), (5.,)), predictions.eval()) +class TransformFeaturesTest(test.TestCase): + + # All transform tests are distributed in column test. + # Here we only test multi column case and naming + def transform_multi_column(self): + bucketized_price = fc.bucketized_column( + fc.numeric_column('price'), boundaries=[0, 2, 4, 6]) + hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10) + with ops.Graph().as_default(): + features = { + 'price': [[-1.], [5.]], + 'wire': + sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + } + transformed = fc._transform_features(features, + [bucketized_price, hashed_sparse]) + with _initialized_session(): + self.assertIn(bucketized_price.name, transformed[bucketized_price].name) + self.assertAllEqual([[0], [3]], transformed[bucketized_price].eval()) + self.assertIn(hashed_sparse.name, transformed[hashed_sparse].name) + self.assertAllEqual([6, 4, 1], transformed[hashed_sparse].values.eval()) + + def test_column_order(self): + """When the column is both dense and sparse, uses sparse tensors.""" + + class _LoggerColumn(fc._FeatureColumn): + + def __init__(self, name): + self._name = name + + @property + def name(self): + return self._name + + def _transform_feature(self, inputs): + del inputs + self.call_order = call_logger['count'] + call_logger['count'] += 1 + return 'Anything' + + @property + def _parse_example_config(self): + pass + + with ops.Graph().as_default(): + column1 = _LoggerColumn('1') + column2 = _LoggerColumn('2') + call_logger = {'count': 0} + fc._transform_features({}, [column1, column2]) + self.assertEqual(0, column1.call_order) + self.assertEqual(1, column2.call_order) + + call_logger = {'count': 0} + fc._transform_features({}, [column2, column1]) + self.assertEqual(0, column1.call_order) + self.assertEqual(1, column2.call_order) + + if __name__ == '__main__': test.main() |