about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-05-11 13:27:35 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-05-11 13:31:04 -0700
commit: 3c2dc3baaae762b00d90761f47265411f54033b3 (patch)
tree: a9e2e77aa10473aee7670ee7cddcdfba0b5d9f95
parent: e1e820efb0e46dafd70d8a776b26962927e64454 (diff)
Renames _parse_example_config to _parse_example_spec, adds tests and other cleanup.
PiperOrigin-RevId: 155788799
-rw-r--r-- tensorflow/python/feature_column/feature_column.py | 64
-rw-r--r-- tensorflow/python/feature_column/feature_column_test.py | 366
2 files changed, 345 insertions, 85 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index c9baef0695..f0a7de8668 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -408,7 +408,7 @@ def make_parse_example_spec(feature_columns):
raise ValueError(
'All feature_columns must be _FeatureColumn instances. '
'Given: {}'.format(column))
- config = column._parse_example_config # pylint: disable=protected-access
+ config = column._parse_example_spec # pylint: disable=protected-access
for key, value in six.iteritems(config):
if key in result and value != result[key]:
raise ValueError(
@@ -484,10 +484,8 @@ def embedding_column(
'Embedding of column_name: {}'.format(
categorical_column.name))
if initializer is None:
- # pylint: disable=protected-access
initializer = init_ops.truncated_normal_initializer(
mean=0.0, stddev=1 / math.sqrt(dimension))
- # pylint: enable=protected-access
return _EmbeddingColumn(
categorical_column=categorical_column,
@@ -1197,7 +1195,7 @@ class _FeatureColumn(object):
pass
@abc.abstractproperty
- def _parse_example_config(self):
+ def _parse_example_spec(self):
"""Returns a `tf.Example` parsing spec as dict.
It is used for get_parsing_spec for `tf.parse_example`. Returned spec is a
@@ -1207,11 +1205,11 @@ class _FeatureColumn(object):
Let's say a Feature column depends on raw feature ('raw') and another
`_FeatureColumn` (input_fc). One possible implementation of
- _parse_example_config is as follows:
+ _parse_example_spec is as follows:
```python
spec = {'raw': tf.FixedLenFeature(...)}
- spec.update(input_fc._parse_example_config)
+ spec.update(input_fc._parse_example_spec)
return spec
```
"""
@@ -1428,9 +1426,7 @@ class _LazyBuilder(object):
column = key
logging.debug('Transforming feature_column %s.', column)
- # pylint: disable=protected-access
- transformed = column._transform_feature(self)
- # pylint: enable=protected-access
+ transformed = column._transform_feature(self) # pylint: disable=protected-access
if transformed is None:
raise ValueError('Column {} is not supported.'.format(column.name))
self._feature_tensors[column] = transformed
@@ -1529,7 +1525,7 @@ class _NumericColumn(_DenseColumn,
return self.key
@property
- def _parse_example_config(self):
+ def _parse_example_spec(self):
return {
self.key:
parsing_ops.FixedLenFeature(self.shape, self.dtype,
@@ -1582,8 +1578,8 @@ class _BucketizedColumn(_DenseColumn, _CategoricalColumn,
return '{}_bucketized'.format(self.source_column.name)
@property
- def _parse_example_config(self):
- return self.source_column._parse_example_config # pylint: disable=protected-access
+ def _parse_example_spec(self):
+ return self.source_column._parse_example_spec # pylint: disable=protected-access
def _transform_feature(self, inputs):
source_tensor = inputs.get(self.source_column)
@@ -1655,10 +1651,8 @@ class _EmbeddingColumn(
return self._name
@property
- def _parse_example_config(self):
- # pylint: disable=protected-access
- return self.categorical_column._parse_example_config
- # pylint: enable=protected-access
+ def _parse_example_spec(self):
+ return self.categorical_column._parse_example_spec # pylint: disable=protected-access
def _transform_feature(self, inputs):
return inputs.get(self.categorical_column)
@@ -1671,19 +1665,15 @@ class _EmbeddingColumn(
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
# Get sparse IDs and weights.
- # pylint: disable=protected-access
- sparse_tensors = self.categorical_column._get_sparse_tensors(
+ sparse_tensors = self.categorical_column._get_sparse_tensors( # pylint: disable=protected-access
inputs, weight_collections=weight_collections, trainable=trainable)
- # pylint: enable=protected-access
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
# Create embedding weight, and restore from checkpoint if necessary.
embedding_weights = variable_scope.get_variable(
name='embedding_weights',
- # pylint: disable=protected-access
- shape=(self.categorical_column._num_buckets, self.dimension),
- # pylint: enable=protected-access
+ shape=(self.categorical_column._num_buckets, self.dimension), # pylint: disable=protected-access
dtype=dtypes.float32,
initializer=self.initializer,
trainable=self.trainable and trainable,
@@ -1691,9 +1681,7 @@ class _EmbeddingColumn(
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
if isinstance(to_restore, variables.PartitionedVariable):
- # pylint: disable=protected-access
- to_restore = to_restore._get_variable_list()
- # pylint: enable=protected-access
+ to_restore = to_restore._get_variable_list() # pylint: disable=protected-access
checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
self.tensor_name_in_ckpt: to_restore
})
@@ -1824,7 +1812,7 @@ class _HashedCategoricalColumn(
return self.key
@property
- def _parse_example_config(self):
+ def _parse_example_spec(self):
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
@@ -1875,7 +1863,7 @@ class _VocabularyFileCategoricalColumn(
return self.key
@property
- def _parse_example_config(self):
+ def _parse_example_spec(self):
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
@@ -1927,7 +1915,7 @@ class _VocabularyListCategoricalColumn(
return self.key
@property
- def _parse_example_config(self):
+ def _parse_example_spec(self):
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
@@ -1978,7 +1966,7 @@ class _IdentityCategoricalColumn(
return self.key
@property
- def _parse_example_config(self):
+ def _parse_example_spec(self):
return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
def _transform_feature(self, inputs):
@@ -2041,10 +2029,8 @@ class _WeightedCategoricalColumn(
self.categorical_column.name, self.weight_column_name)
@property
- def _parse_example_config(self):
- # pylint: disable=protected-access
- config = self.categorical_column._parse_example_config
- # pylint: enable=protected-access
+ def _parse_example_spec(self):
+ config = self.categorical_column._parse_example_spec # pylint: disable=protected-access
if self.weight_column_name in config:
raise ValueError('Parse config {} already exists for {}.'.format(
config[self.weight_column_name], self.weight_column_name))
@@ -2053,9 +2039,7 @@ class _WeightedCategoricalColumn(
@property
def _num_buckets(self):
- # pylint: disable=protected-access
- return self.categorical_column._num_buckets
- # pylint: enable=protected-access
+ return self.categorical_column._num_buckets # pylint: disable=protected-access
def _transform_feature(self, inputs):
weight_tensor = inputs.get(self.weight_column_name)
@@ -2099,11 +2083,11 @@ class _CrossedColumn(
return '_X_'.join(sorted(feature_names))
@property
- def _parse_example_config(self):
+ def _parse_example_spec(self):
config = {}
for key in self.keys:
if isinstance(key, _FeatureColumn):
- config.update(key._parse_example_config) # pylint: disable=protected-access
+ config.update(key._parse_example_spec) # pylint: disable=protected-access
else: # key must be a string
config.update({key: parsing_ops.VarLenFeature(dtypes.string)})
return config
@@ -2348,8 +2332,8 @@ class _IndicatorColumn(_DenseColumn,
return math_ops.reduce_sum(one_hot_id_tensor, axis=[1])
@property
- def _parse_example_config(self):
- return self.categorical_column._parse_example_config # pylint: disable=protected-access
+ def _parse_example_spec(self):
+ return self.categorical_column._parse_example_spec # pylint: disable=protected-access
@property
def _variable_shape(self):
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index b09c01d266..977ab81ff6 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -65,7 +65,7 @@ class LazyColumnTest(test.TestCase):
return cache.get('a')
@property
- def _parse_example_config(self):
+ def _parse_example_spec(self):
pass
builder = fc._LazyBuilder(features={'a': [[2], [3.]]})
@@ -88,7 +88,7 @@ class LazyColumnTest(test.TestCase):
return 'Output'
@property
- def _parse_example_config(self):
+ def _parse_example_spec(self):
pass
builder = fc._LazyBuilder(features={'a': [[2], [3.]]})
@@ -108,7 +108,7 @@ class LazyColumnTest(test.TestCase):
return 'Output'
@property
- def _parse_example_config(self):
+ def _parse_example_spec(self):
pass
features = {'a': [[2], [3.]]}
@@ -135,7 +135,7 @@ class LazyColumnTest(test.TestCase):
pass
@property
- def _parse_example_config(self):
+ def _parse_example_spec(self):
pass
builder = fc._LazyBuilder(features={'a': [[2], [3.]]})
@@ -222,11 +222,11 @@ class NumericColumnTest(test.TestCase):
a = fc.numeric_column('aaa', shape=[2, 3], default_value=2.)
self.assertEqual(((2., 2., 2.), (2., 2., 2.)), a.default_value)
- def test_parse_config(self):
+ def test_parse_spec(self):
a = fc.numeric_column('aaa', shape=[2, 3], dtype=dtypes.int32)
self.assertEqual({
'aaa': parsing_ops.FixedLenFeature((2, 3), dtype=dtypes.int32)
- }, a._parse_example_config)
+ }, a._parse_example_spec)
def test_parse_example_no_default_value(self):
price = fc.numeric_column('price', shape=[2])
@@ -238,7 +238,7 @@ class NumericColumnTest(test.TestCase):
}))
features = parsing_ops.parse_example(
serialized=[data.SerializeToString()],
- features=price._parse_example_config)
+ features=fc.make_parse_example_spec([price]))
self.assertIn('price', features)
with self.test_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())
@@ -260,7 +260,7 @@ class NumericColumnTest(test.TestCase):
features = parsing_ops.parse_example(
serialized=[data.SerializeToString(),
no_data.SerializeToString()],
- features=price._parse_example_config)
+ features=fc.make_parse_example_spec([price]))
self.assertIn('price', features)
with self.test_session():
self.assertAllEqual([[20., 110.], [11., 11.]], features['price'].eval())
@@ -362,12 +362,12 @@ class BucketizedColumnTest(test.TestCase):
b = fc.bucketized_column(a, boundaries=[0, 1])
self.assertEqual('aaa_bucketized', b.name)
- def test_parse_config(self):
+ def test_parse_spec(self):
a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
b = fc.bucketized_column(a, boundaries=[0, 1])
self.assertEqual({
'aaa': parsing_ops.FixedLenFeature((2,), dtype=dtypes.int32)
- }, b._parse_example_config)
+ }, b._parse_example_spec)
def test_variable_shape(self):
a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
@@ -392,7 +392,7 @@ class BucketizedColumnTest(test.TestCase):
}))
features = parsing_ops.parse_example(
serialized=[data.SerializeToString()],
- features=bucketized_price._parse_example_config)
+ features=fc.make_parse_example_spec([bucketized_price]))
self.assertIn('price', features)
with self.test_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())
@@ -584,17 +584,38 @@ class HashedCategoricalColumnTest(test.TestCase):
self.assertEqual(10, column._num_buckets)
self.assertEqual(dtypes.string, column.dtype)
- def test_parse_config(self):
+ def test_parse_spec_string(self):
a = fc.categorical_column_with_hash_bucket('aaa', 10)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.string)
- }, a._parse_example_config)
+ }, a._parse_example_spec)
- def test_parse_config_int(self):
+ def test_parse_spec_int(self):
a = fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.int32)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
- }, a._parse_example_config)
+ }, a._parse_example_spec)
+
+ def test_parse_example(self):
+ a = fc.categorical_column_with_hash_bucket('aaa', 10)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer']))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
def test_strings_should_be_hashed(self):
hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
@@ -797,14 +818,14 @@ class CrossedColumnTest(test.TestCase):
crossed2 = fc.crossed_column([crossed1, 'd1', b], 10)
self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name)
- def test_parse_config(self):
+ def test_parse_spec(self):
a = fc.numeric_column('a', shape=[2], dtype=dtypes.int32)
b = fc.bucketized_column(a, boundaries=[0, 1])
crossed = fc.crossed_column([b, 'c'], 10)
self.assertEqual({
'a': parsing_ops.FixedLenFeature((2,), dtype=dtypes.int32),
'c': parsing_ops.VarLenFeature(dtypes.string),
- }, crossed._parse_example_config)
+ }, crossed._parse_example_spec)
def test_num_buckets(self):
a = fc.numeric_column('a', shape=[2], dtype=dtypes.int32)
@@ -837,7 +858,7 @@ class CrossedColumnTest(test.TestCase):
}))
features = parsing_ops.parse_example(
serialized=[data.SerializeToString()],
- features=price_cross_wire._parse_example_config)
+ features=fc.make_parse_example_spec([price_cross_wire]))
self.assertIn('price', features)
self.assertIn('wire', features)
with self.test_session():
@@ -848,6 +869,29 @@ class CrossedColumnTest(test.TestCase):
self.assertAllEqual([b'omar', b'stringer'], wire_sparse.values.eval())
self.assertAllEqual([1, 2], wire_sparse.dense_shape.eval())
+ def test_transform_feature(self):
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
+ hash_bucket_size = 10
+ price_cross_wire = fc.crossed_column(
+ [bucketized_price, 'wire'], hash_bucket_size)
+ features = {
+ 'price': constant_op.constant([[1., 2.], [5., 6.]]),
+ 'wire': sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ }
+ outputs = fc._transform_features(features, [price_cross_wire])
+ output = outputs[price_cross_wire]
+ with self.test_session() as sess:
+ output_val = sess.run(output)
+ self.assertAllEqual(
+ [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
+ for val in output_val.values:
+ self.assertIn(val, list(range(hash_bucket_size)))
+ self.assertAllEqual([2, 4], output_val.dense_shape)
+
def test_get_sparse_tensors(self):
a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
b = fc.bucketized_column(a, boundaries=(0, 1))
@@ -947,7 +991,7 @@ class CrossedColumnTest(test.TestCase):
return 'test_column'
@property
- def _parse_example_config(self):
+ def _parse_example_spec(self):
return {
self.name: parsing_ops.VarLenFeature(dtypes.int32),
'{}_weights'.format(self.name): parsing_ops.VarLenFeature(
@@ -1020,7 +1064,7 @@ class MakeLinearModelTest(test.TestCase):
pass
@property
- def _parse_example_config(self):
+ def _parse_example_spec(self):
pass
with self.assertRaisesRegexp(
@@ -1102,7 +1146,7 @@ class MakeLinearModelTest(test.TestCase):
return 'dense_and_sparse_column'
@property
- def _parse_example_config(self):
+ def _parse_example_spec(self):
return {self.name: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
@@ -1473,11 +1517,11 @@ class MakeParseExampleSpecTest(test.TestCase):
class _TestFeatureColumn(
fc._FeatureColumn,
- collections.namedtuple('_TestFeatureColumn', ['parse_config'])):
+ collections.namedtuple('_TestFeatureColumn', ['parse_spec'])):
@property
- def _parse_example_config(self):
- return self.parse_config
+ def _parse_example_spec(self):
+ return self.parse_spec
def test_no_feature_columns(self):
actual = fc.make_parse_example_spec([])
@@ -1583,7 +1627,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
self.assertEqual(3, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.string)
- }, column._parse_example_config)
+ }, column._parse_example_spec)
def test_all_constructor_args(self):
column = fc.categorical_column_with_vocabulary_file(
@@ -1592,7 +1636,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
self.assertEqual(7, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
- }, column._parse_example_config)
+ }, column._parse_example_spec)
def test_deep_copy(self):
original = fc.categorical_column_with_vocabulary_file(
@@ -1603,7 +1647,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
self.assertEqual(7, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
- }, column._parse_example_config)
+ }, column._parse_example_spec)
def test_vocabulary_file_none(self):
with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'):
@@ -1703,6 +1747,28 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
+ def test_parse_example(self):
+ a = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer']))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
def test_get_sparse_tensors(self):
column = fc.categorical_column_with_vocabulary_file(
key='aaa',
@@ -1946,7 +2012,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
self.assertEqual(3, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.string)
- }, column._parse_example_config)
+ }, column._parse_example_spec)
def test_defaults_int(self):
column = fc.categorical_column_with_vocabulary_list(
@@ -1955,7 +2021,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
self.assertEqual(3, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
- }, column._parse_example_config)
+ }, column._parse_example_spec)
def test_all_constructor_args(self):
column = fc.categorical_column_with_vocabulary_list(
@@ -1964,7 +2030,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
self.assertEqual(3, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
- }, column._parse_example_config)
+ }, column._parse_example_spec)
def test_deep_copy(self):
original = fc.categorical_column_with_vocabulary_list(
@@ -1974,7 +2040,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
self.assertEqual(3, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int32)
- }, column._parse_example_config)
+ }, column._parse_example_spec)
def test_invalid_dtype(self):
with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
@@ -2040,6 +2106,50 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
+ def test_parse_example_string(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer']))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_parse_example_int(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(11, 21, 31))
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(int64_list=feature_pb2.Int64List(
+ value=[11, 21]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=[11, 21],
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
def test_get_sparse_tensors(self):
column = fc.categorical_column_with_vocabulary_list(
key='aaa',
@@ -2060,6 +2170,24 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
dense_shape=inputs.dense_shape),
id_weight_pair.id_tensor.eval())
+ def test_transform_feature(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_tensor = fc._transform_features({'aaa': inputs}, [column])[column]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_tensor.eval())
+
def test_get_sparse_tensors_weight_collections(self):
column = fc.categorical_column_with_vocabulary_list(
key='aaa',
@@ -2188,7 +2316,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
self.assertEqual(3, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
- }, column._parse_example_config)
+ }, column._parse_example_spec)
def test_deep_copy(self):
original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
@@ -2197,7 +2325,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
self.assertEqual(3, column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
- }, column._parse_example_config)
+ }, column._parse_example_spec)
def test_invalid_num_buckets_zero(self):
with self.assertRaisesRegexp(ValueError, 'num_buckets 0 < 1'):
@@ -2226,6 +2354,27 @@ class IdentityCategoricalColumnTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'):
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
+ def test_parse_example(self):
+ a = fc.categorical_column_with_identity(key='aaa', num_buckets=30)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(int64_list=feature_pb2.Int64List(
+ value=[11, 21]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([11, 21], dtype=np.int64),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
def test_get_sparse_tensors(self):
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
inputs = sparse_tensor.SparseTensorValue(
@@ -2244,6 +2393,22 @@ class IdentityCategoricalColumnTest(test.TestCase):
dense_shape=inputs.dense_shape),
id_weight_pair.id_tensor.eval())
+ def test_transform_feature(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2))
+ id_tensor = fc._transform_features({'aaa': inputs}, [column])[column]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((0, 1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_tensor.eval())
+
def test_get_sparse_tensors_weight_collections(self):
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
inputs = sparse_tensor.SparseTensorValue(
@@ -2411,7 +2576,7 @@ class TransformFeaturesTest(test.TestCase):
return 'Anything'
@property
- def _parse_example_config(self):
+ def _parse_example_spec(self):
pass
with ops.Graph().as_default():
@@ -2489,7 +2654,7 @@ class IndicatorColumnTest(test.TestCase):
with self.test_session():
self.assertAllEqual([[0., 1., 1., 0.]], output.eval())
- def test_indicator_column_deep_copy(self):
+ def test_deep_copy(self):
a = fc.categorical_column_with_hash_bucket('a', 4)
column = fc.indicator_column(a)
column_copy = copy.deepcopy(column)
@@ -2497,6 +2662,44 @@ class IndicatorColumnTest(test.TestCase):
self.assertEqual(column.name, 'a_indicator')
self.assertEqual(column._variable_shape, [1, 4])
+ def test_parse_example(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ a_indicator = fc.indicator_column(a)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer']))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a_indicator]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_transform(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ a_indicator = fc.indicator_column(a)
+ features = {
+ 'aaa': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }
+ indicator_tensor = fc._transform_features(
+ features, [a_indicator])[a_indicator]
+ with _initialized_session():
+ self.assertAllEqual([[0, 0, 1], [1, 0, 0]], indicator_tensor.eval())
+
def test_make_linear_model(self):
animal = fc.indicator_column(
fc.categorical_column_with_identity('animal', num_buckets=4))
@@ -2551,7 +2754,7 @@ class EmbeddingColumnTest(test.TestCase):
(embedding_dimension,), embedding_column._variable_shape)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
- }, embedding_column._parse_example_config)
+ }, embedding_column._parse_example_spec)
def test_all_constructor_args(self):
categorical_column = fc.categorical_column_with_identity(
@@ -2575,7 +2778,7 @@ class EmbeddingColumnTest(test.TestCase):
(embedding_dimension,), embedding_column._variable_shape)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
- }, embedding_column._parse_example_config)
+ }, embedding_column._parse_example_spec)
def test_deep_copy(self):
categorical_column = fc.categorical_column_with_identity(
@@ -2591,7 +2794,7 @@ class EmbeddingColumnTest(test.TestCase):
self.assertEqual(3, embedding_column.categorical_column._num_buckets)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
- }, embedding_column.categorical_column._parse_example_config)
+ }, embedding_column.categorical_column._parse_example_spec)
self.assertEqual(embedding_dimension, embedding_column.dimension)
self.assertEqual('my_combiner', embedding_column.combiner)
@@ -2605,7 +2808,7 @@ class EmbeddingColumnTest(test.TestCase):
(embedding_dimension,), embedding_column._variable_shape)
self.assertEqual({
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
- }, embedding_column._parse_example_config)
+ }, embedding_column._parse_example_spec)
def test_invalid_initializer(self):
categorical_column = fc.categorical_column_with_identity(
@@ -2613,6 +2816,45 @@ class EmbeddingColumnTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'initializer must be callable'):
fc.embedding_column(categorical_column, dimension=2, initializer='not_fn')
+ def test_parse_example(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ a_embedded = fc.embedding_column(a, dimension=2)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer']))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a_embedded]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_transform_feature(self):
+ a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ a_embedded = fc.embedding_column(a, dimension=2)
+ features = {
+ 'aaa': sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2))
+ }
+ outputs = fc._transform_features(features, [a, a_embedded])
+ output_a = outputs[a]
+ output_embedded = outputs[a_embedded]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self, output_a.eval(), output_embedded.eval())
+
def test_get_dense_tensor(self):
# Inputs.
vocabulary_size = 3
@@ -3090,7 +3332,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
self.assertEqual({
'ids': parsing_ops.VarLenFeature(dtypes.int64),
'values': parsing_ops.VarLenFeature(dtypes.float32)
- }, column._parse_example_config)
+ }, column._parse_example_spec)
def test_deep_copy(self):
"""Tests deepcopy of categorical_column_with_hash_bucket."""
@@ -3104,7 +3346,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
self.assertEqual({
'ids': parsing_ops.VarLenFeature(dtypes.int64),
'values': parsing_ops.VarLenFeature(dtypes.float32)
- }, column._parse_example_config)
+ }, column._parse_example_spec)
def test_invalid_dtype_none(self):
with self.assertRaisesRegexp(ValueError, 'is not convertible to float'):
@@ -3139,7 +3381,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
fc.weighted_categorical_column(
categorical_column=fc.categorical_column_with_identity(
key='aaa', num_buckets=3),
- weight_column_name='aaa')._parse_example_config()
+ weight_column_name='aaa')._parse_example_spec()
def test_missing_weights(self):
column = fc.weighted_categorical_column(
@@ -3154,7 +3396,41 @@ class WeightedCategoricalColumnTest(test.TestCase):
ValueError, 'values is not in features dictionary'):
fc._transform_features({'ids': inputs}, (column,))
- def test_get_sparse_tensors(self):
+ def test_parse_example(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ a_weighted = fc.weighted_categorical_column(a, weight_column_name='weights')
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer'])),
+ 'weights':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[1., 10.]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a_weighted]))
+ self.assertIn('aaa', features)
+ self.assertIn('weights', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([1., 10.], dtype=np.float32),
+ dense_shape=[1, 2]),
+ features['weights'].eval())
+
+ def test_transform_features(self):
column = fc.weighted_categorical_column(
categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
@@ -3187,7 +3463,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
dense_shape=weights.dense_shape),
weight_tensor.eval())
- def test_get_sparse_tensors_dense_input(self):
+ def test_transform_features_dense_input(self):
column = fc.weighted_categorical_column(
categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
@@ -3216,7 +3492,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
dense_shape=weights.dense_shape),
weight_tensor.eval())
- def test_get_sparse_tensors_dense_weights(self):
+ def test_transform_features_dense_weights(self):
column = fc.weighted_categorical_column(
categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),