aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-05-09 12:32:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-10 16:51:41 -0700
commit1bad658d63e5fb21d321cd680e7451c96e032f7e (patch)
tree715a70bafaa94e6f6d977e2129fc8222ecfbd857
parent557ab679e730d250d665695d1556d5c2b25f7f07 (diff)
Moved transform_features.
removed pylint lines from test. PiperOrigin-RevId: 155538004
-rw-r--r--tensorflow/python/feature_column/feature_column.py44
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py171
2 files changed, 132 insertions, 83 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 70c38574fb..6efab10efb 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -313,6 +313,50 @@ def make_linear_model(features,
return predictions
+def _transform_features(features, feature_columns):
+ """Returns transformed features based on features columns passed in.
+
+ Please note that most probably you would not need to use this function. Please
+ check `make_input_layer` and `make_linear_model` to see whether they will
+ satisfy your use case or not.
+
+ Example:
+
+ ```python
+ # Define features and transformations
+ crosses_a_x_b = crossed_column(
+ columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000)
+ price_buckets = bucketized_column(
+ source_column=numeric_column("price"), boundaries=[...])
+
+ columns = [crosses_a_x_b, price_buckets]
+ features = tf.parse_example(..., features=parse_example_spec(columns))
+ transformed = transform_features(features=features, feature_columns=columns)
+
+ assertCountEqual(columns, transformed.keys())
+ ```
+
+ Args:
+ features: A mapping from key to tensors. `FeatureColumn`s look up via these
+ keys. For example `numeric_column('price') will look at 'price' key in
+ this dict. Values can be a `SparseTensor` or a `Tensor` depends on
+ corresponding `FeatureColumn`.
+ feature_columns: An iterable containing all the `FeatureColumn`s.
+
+ Returns:
+ A `dict` mapping FeatureColumn to `Tensor` and `SparseTensor` values.
+ """
+ _check_feature_columns(feature_columns)
+ outputs = {}
+ with ops.name_scope(
+ None, default_name='transform_features', values=features.values()):
+ builder = _LazyBuilder(features)
+ for column in sorted(feature_columns, key=lambda x: x.name):
+ with ops.name_scope(None, default_name=column.name):
+ outputs[column] = builder.get(column)
+ return outputs
+
+
def numeric_column(key,
shape=(1,),
default_value=None,
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 5201811831..bd9a7c8df8 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -274,12 +274,9 @@ class NumericColumnTest(test.TestCase):
return input_tensor + 2.
price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
- builder = fc._LazyBuilder({
- 'price': [[1., 2.], [5., 6.]]
- })
- output = builder.get(price)
+ output = fc._transform_features({'price': [[1., 2.], [5., 6.]]}, [price])
with self.test_session():
- self.assertAllEqual([[3., 4.], [7., 8.]], output.eval())
+ self.assertAllEqual([[3., 4.], [7., 8.]], output[price].eval())
def test_get_dense_tensor(self):
@@ -403,12 +400,12 @@ class BucketizedColumnTest(test.TestCase):
price = fc.numeric_column('price', shape=[2])
bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
- builder = fc._LazyBuilder({
+ transformed_tensor = fc._transform_features({
'price': [[-1., 1.], [5., 6.]]
- })
- transformed_tensor = builder.get(bucketized_price)
+ }, [bucketized_price])
with _initialized_session():
- self.assertAllEqual([[0, 1], [3, 4]], transformed_tensor.eval())
+ self.assertAllEqual([[0, 1], [3, 4]],
+ transformed_tensor[bucketized_price].eval())
def test_get_dense_tensor_one_input_value(self):
"""Tests _get_dense_tensor() for input with shape=[1]."""
@@ -584,9 +581,7 @@ class HashedCategoricalColumnTest(test.TestCase):
for column in (original, copy.deepcopy(original)):
self.assertEqual('aaa', column.name)
self.assertEqual(10, column.hash_bucket_size)
- # pylint: disable=protected-access
self.assertEqual(10, column._num_buckets)
- # pylint: enable=protected-access
self.assertEqual(dtypes.string, column.dtype)
def test_parse_config(self):
@@ -607,8 +602,8 @@ class HashedCategoricalColumnTest(test.TestCase):
values=['omar', 'stringer', 'marlo'],
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
- builder = fc._LazyBuilder({'wire': wire_tensor})
- output = builder.get(hashed_sparse)
+ outputs = fc._transform_features({'wire': wire_tensor}, [hashed_sparse])
+ output = outputs[hashed_sparse]
# Check exact hashed output. If hashing changes this test will break.
expected_values = [6, 4, 1]
with self.test_session():
@@ -1225,23 +1220,19 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
column = fc.categorical_column_with_vocabulary_file(
key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
self.assertEqual('aaa', column.name)
- # pylint: disable=protected-access
self.assertEqual(3, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.string)
}, column._parse_example_config)
- # pylint: enable=protected-access
def test_all_constructor_args(self):
column = fc.categorical_column_with_vocabulary_file(
key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
num_oov_buckets=4, dtype=dtypes.int32)
- # pylint: disable=protected-access
self.assertEqual(7, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
}, column._parse_example_config)
- # pylint: enable=protected-access
def test_deep_copy(self):
"""Tests deepcopy of categorical_column_with_hash_bucket."""
@@ -1250,12 +1241,10 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
num_oov_buckets=4, dtype=dtypes.int32)
for column in (original, copy.deepcopy(original)):
self.assertEqual('aaa', column.name)
- # pylint: disable=protected-access
self.assertEqual(7, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
}, column._parse_example_config)
- # pylint: enable=protected-access
def test_vocabulary_file_none(self):
with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'):
@@ -1274,9 +1263,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- # pylint: disable=protected-access
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'):
with self.test_session():
lookup_ops.tables_initializer().run()
@@ -1304,9 +1291,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- # pylint: disable=protected-access
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'):
with self.test_session():
lookup_ops.tables_initializer().run()
@@ -1344,9 +1329,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
values=(12, 24, 36),
dense_shape=(2, 2))
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
- # pylint: disable=protected-access
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
def test_invalid_input_dtype_string(self):
column = fc.categorical_column_with_vocabulary_file(
@@ -1359,9 +1342,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
values=('omar', 'stringer', 'marlo'),
dense_shape=(2, 2))
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
- # pylint: disable=protected-access
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
def test_get_sparse_tensors(self):
column = fc.categorical_column_with_vocabulary_file(
@@ -1372,10 +1353,8 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(
fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1386,16 +1365,33 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=inputs.dense_shape),
id_weight_pair.id_tensor.eval())
+ def test_transform_feature(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_tensor = fc._transform_features({'aaa': inputs}, [column])[column]
+ with _initialized_session():
+ _assert_sparse_tensor_value(self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array(
+ (2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_tensor.eval())
+
def test_get_sparse_tensors_dense_input(self):
column = fc.categorical_column_with_vocabulary_file(
key='aaa',
vocabulary_file=self._wire_vocabulary_file_name,
vocabulary_size=self._wire_vocabulary_size)
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
'aaa': (('marlo', ''), ('skywalker', 'omar'))
}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1416,10 +1412,8 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(
fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1440,10 +1434,8 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1), (1, 2)),
values=('marlo', 'skywalker', 'omar', 'heisenberg'),
dense_shape=(2, 3))
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(
fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1466,10 +1458,8 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(
fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1490,10 +1480,8 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1), (2, 2)),
values=(11, 100, 30, 22),
dense_shape=(3, 3))
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(
fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1512,11 +1500,9 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
vocabulary_size=self._warriors_vocabulary_size,
dtype=dtypes.int32,
default_value=default_value)
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
'aaa': ((11, -1, -1), (100, 30, -1), (-1, -1, 22))
}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1538,10 +1524,8 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1), (2, 2)),
values=(11, 100, 30, 22),
dense_shape=(3, 3))
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(
fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1584,34 +1568,28 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
column = fc.categorical_column_with_vocabulary_list(
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
self.assertEqual('aaa', column.name)
- # pylint: disable=protected-access
self.assertEqual(3, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.string)
}, column._parse_example_config)
- # pylint: enable=protected-access
def test_defaults_int(self):
column = fc.categorical_column_with_vocabulary_list(
key='aaa', vocabulary_list=(12, 24, 36))
self.assertEqual('aaa', column.name)
- # pylint: disable=protected-access
self.assertEqual(3, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
}, column._parse_example_config)
- # pylint: enable=protected-access
def test_all_constructor_args(self):
column = fc.categorical_column_with_vocabulary_list(
key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32,
default_value=-99)
- # pylint: disable=protected-access
self.assertEqual(3, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
}, column._parse_example_config)
- # pylint: enable=protected-access
def test_deep_copy(self):
"""Tests deepcopy of categorical_column_with_hash_bucket."""
@@ -1619,12 +1597,10 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32)
for column in (original, copy.deepcopy(original)):
self.assertEqual('aaa', column.name)
- # pylint: disable=protected-access
self.assertEqual(3, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
}, column._parse_example_config)
- # pylint: enable=protected-access
def test_invalid_dtype(self):
with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
@@ -1677,9 +1653,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
values=(12, 24, 36),
dense_shape=(2, 2))
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
- # pylint: disable=protected-access
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
def test_invalid_input_dtype_string(self):
column = fc.categorical_column_with_vocabulary_list(
@@ -1690,9 +1664,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
values=('omar', 'stringer', 'marlo'),
dense_shape=(2, 2))
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
- # pylint: disable=protected-access
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
def test_get_sparse_tensors(self):
column = fc.categorical_column_with_vocabulary_list(
@@ -1702,10 +1674,8 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(
fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1720,11 +1690,9 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
column = fc.categorical_column_with_vocabulary_list(
key='aaa',
vocabulary_list=('omar', 'stringer', 'marlo'))
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
'aaa': (('marlo', ''), ('skywalker', 'omar'))
}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1744,10 +1712,8 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(
fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1767,10 +1733,8 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1), (2, 2)),
values=np.array((11, 100, 30, 22), dtype=np.int32),
dense_shape=(3, 3))
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(
fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1788,13 +1752,11 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32),
dtype=dtypes.int32,
default_value=default_value)
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
'aaa': np.array(
((11, -1, -1), (100, 30, -1), (-1, -1, 22)),
dtype=np.int32)
}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1834,24 +1796,20 @@ class IdentityCategoricalColumnTest(test.TestCase):
def test_constructor(self):
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
self.assertEqual('aaa', column.name)
- # pylint: disable=protected-access
self.assertEqual(3, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
}, column._parse_example_config)
- # pylint: enable=protected-access
def test_deep_copy(self):
"""Tests deepcopy of categorical_column_with_hash_bucket."""
original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
for column in (original, copy.deepcopy(original)):
self.assertEqual('aaa', column.name)
- # pylint: disable=protected-access
self.assertEqual(3, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
}, column._parse_example_config)
- # pylint: enable=protected-access
def test_invalid_num_buckets_zero(self):
with self.assertRaisesRegexp(ValueError, 'num_buckets 0 < 1'):
@@ -1878,9 +1836,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
values=('omar', 'stringer', 'marlo'),
dense_shape=(2, 2))
with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'):
- # pylint: disable=protected-access
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
def test_get_sparse_tensors(self):
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
@@ -1888,10 +1844,8 @@ class IdentityCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=(0, 1, 0),
dense_shape=(2, 2))
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(
fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1904,11 +1858,9 @@ class IdentityCategoricalColumnTest(test.TestCase):
def test_get_sparse_tensors_dense_input(self):
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
'aaa': ((0, -1), (1, 0))
}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1925,10 +1877,8 @@ class IdentityCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=(1, -1, 0),
dense_shape=(2, 2))
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(
fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
with self.assertRaisesRegexp(
@@ -1941,10 +1891,8 @@ class IdentityCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=(1, 99, 0),
dense_shape=(2, 2))
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(
fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
with self.assertRaisesRegexp(
@@ -1958,10 +1906,8 @@ class IdentityCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=(1, -1, 99),
dense_shape=(2, 2))
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(
fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -1982,10 +1928,8 @@ class IdentityCategoricalColumnTest(test.TestCase):
indices=input_indices,
values=input_values,
dense_shape=input_shape)
- # pylint: disable=protected-access
id_weight_pair = column._get_sparse_tensors(
fc._LazyBuilder({'aaa': inputs}))
- # pylint: enable=protected-access
self.assertIsNone(id_weight_pair.weight_tensor)
with _initialized_session():
_assert_sparse_tensor_value(
@@ -2022,5 +1966,66 @@ class IdentityCategoricalColumnTest(test.TestCase):
self.assertAllClose(((1.,), (5.,)), predictions.eval())
+class TransformFeaturesTest(test.TestCase):
+
+ # All transform tests are distributed in column test.
+ # Here we only test multi column case and naming
+ def transform_multi_column(self):
+ bucketized_price = fc.bucketized_column(
+ fc.numeric_column('price'), boundaries=[0, 2, 4, 6])
+ hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
+ with ops.Graph().as_default():
+ features = {
+ 'price': [[-1.], [5.]],
+ 'wire':
+ sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ }
+ transformed = fc._transform_features(features,
+ [bucketized_price, hashed_sparse])
+ with _initialized_session():
+ self.assertIn(bucketized_price.name, transformed[bucketized_price].name)
+ self.assertAllEqual([[0], [3]], transformed[bucketized_price].eval())
+ self.assertIn(hashed_sparse.name, transformed[hashed_sparse].name)
+ self.assertAllEqual([6, 4, 1], transformed[hashed_sparse].values.eval())
+
+ def test_column_order(self):
+ """When the column is both dense and sparse, uses sparse tensors."""
+
+ class _LoggerColumn(fc._FeatureColumn):
+
+ def __init__(self, name):
+ self._name = name
+
+ @property
+ def name(self):
+ return self._name
+
+ def _transform_feature(self, inputs):
+ del inputs
+ self.call_order = call_logger['count']
+ call_logger['count'] += 1
+ return 'Anything'
+
+ @property
+ def _parse_example_config(self):
+ pass
+
+ with ops.Graph().as_default():
+ column1 = _LoggerColumn('1')
+ column2 = _LoggerColumn('2')
+ call_logger = {'count': 0}
+ fc._transform_features({}, [column1, column2])
+ self.assertEqual(0, column1.call_order)
+ self.assertEqual(1, column2.call_order)
+
+ call_logger = {'count': 0}
+ fc._transform_features({}, [column2, column1])
+ self.assertEqual(0, column1.call_order)
+ self.assertEqual(1, column2.call_order)
+
+
if __name__ == '__main__':
test.main()