author    Jianwei Xie <xiejw@google.com> 2018-06-14 21:31:23 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-06-14 21:34:35 -0700
commit 284ad32b7f42a835d0cb545061fb354b4f96e0c9 (patch)
tree   cba13abc2360bb48c5b2a82158e63c7d2c823b7c /tensorflow/python/feature_column
parent 3cd4eda38e12351c06d45d0780e16d482491ab95 (diff)
Improves the docstring and comments about the feature column library.
PiperOrigin-RevId: 200667467
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--  tensorflow/python/feature_column/feature_column.py | 162
1 file changed, 130 insertions(+), 32 deletions(-)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index f959b5e484..a58c5aabbe 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -172,7 +172,7 @@ def _internal_input_layer(features,
scope=None):
"""See input_layer. `scope` is a name or variable scope to use."""
- feature_columns = _clean_feature_columns(feature_columns)
+ feature_columns = _normalize_feature_columns(feature_columns)
for column in feature_columns:
if not isinstance(column, _DenseColumn):
raise ValueError(
@@ -350,10 +350,23 @@ def linear_model(features,
prediction itself for linear regression problems.
Note on supported columns: `linear_model` treats categorical columns as
- `indicator_column`s while `input_layer` explicitly requires wrapping each
- of them with an `embedding_column` or an `indicator_column`.
+ `indicator_column`s. To be specific, assume the `SparseTensor` input looks
+ like:
- Example:
+ ```python
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [1, 0]: "b"
+ [1, 1]: "c"
+ }
+ ```
+ `linear_model` assigns weights for the presence of "a", "b", "c" implicitly,
+ just like `indicator_column`, while `input_layer` explicitly requires wrapping
+ each of the categorical columns with an `embedding_column` or an
+ `indicator_column`.
+
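A hedged, runnable sketch of this behavior (the `keywords` key and vocabulary are illustrative, not part of the original docstring):

```python
import tensorflow as tf

# Input matching the SparseTensor sketch above: row 0 contains "a",
# row 1 contains "b" and "c".
features = {
    'keywords': tf.SparseTensor(
        indices=[[0, 0], [1, 0], [1, 1]],
        values=['a', 'b', 'c'],
        dense_shape=[2, 2])
}
keywords = tf.feature_column.categorical_column_with_vocabulary_list(
    'keywords', vocabulary_list=['a', 'b', 'c'])
# The categorical column is passed directly; linear_model implicitly
# treats it like an indicator_column.
predictions = tf.feature_column.linear_model(features, [keywords])
```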
+ Example of usage:
```python
price = numeric_column('price')
@@ -374,13 +387,44 @@ def linear_model(features,
to your model. All items should be instances of classes derived from
`_FeatureColumn`s.
units: An integer, dimensionality of the output space. Default value is 1.
- sparse_combiner: A string specifying how to reduce if a sparse column is
- multivalent. Currently "mean", "sqrtn" and "sum" are supported, with "sum"
- the default. "sqrtn" often achieves good accuracy, in particular with
- bag-of-words columns. It combines each sparse columns independently.
+ sparse_combiner: A string specifying how to reduce if a categorical column
+ is multivalent. Except for `numeric_column`, almost all columns passed to
+ `linear_model` are treated as categorical columns. Each categorical column
+ is combined independently. Currently "mean", "sqrtn" and "sum" are
+ supported, with "sum" the default for the linear model. "sqrtn" often
+ achieves good accuracy, in particular with bag-of-words columns.
* "sum": do not normalize features in the column
* "mean": do l1 normalization on features in the column
* "sqrtn": do l2 normalization on features in the column
+ For example, given two features represented as categorical columns:
+
+ ```python
+ # Feature 1
+
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [0, 1]: "b"
+ [1, 0]: "c"
+ }
+
+ # Feature 2
+
+ shape = [2, 3]
+ {
+ [0, 0]: "d"
+ [1, 0]: "e"
+ [1, 1]: "f"
+ [1, 2]: "g"
+ }
+ ```
+ with `sparse_combiner` as "mean", the linear model outputs conceptually are:
+ ```
+ y_0 = 1.0 / 2.0 * (w_a + w_b) + w_d + b
+ y_1 = w_c + 1.0 / 3.0 * (w_e + w_f + w_g) + b
+ ```
+ where `y_i` is the output for the i-th example, `b` is the bias, and `w_x`
+ is the weight assigned to the presence of `x` in the input features (a
+ runnable sketch follows below).
weight_collections: A list of collection names to which the Variable will be
added. Note that, variables will also be added to collections
`tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
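The worked `sparse_combiner` example above can be reproduced with a hedged sketch (the feature keys `f1` and `f2` are illustrative):

```python
import tensorflow as tf

f1 = tf.feature_column.categorical_column_with_vocabulary_list(
    'f1', ['a', 'b', 'c'])
f2 = tf.feature_column.categorical_column_with_vocabulary_list(
    'f2', ['d', 'e', 'f', 'g'])
features = {
    'f1': tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                          values=['a', 'b', 'c'], dense_shape=[2, 2]),
    'f2': tf.SparseTensor(indices=[[0, 0], [1, 0], [1, 1], [1, 2]],
                          values=['d', 'e', 'f', 'g'], dense_shape=[2, 3]),
}
# With "mean", each multivalent row is averaged per column before the
# weighted sum: y_0 = (w_a + w_b) / 2 + w_d + b.
predictions = tf.feature_column.linear_model(
    features, [f1, f2], sparse_combiner='mean')
```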
@@ -536,7 +580,8 @@ class _LinearModel(training.Model):
name=None,
**kwargs):
super(_LinearModel, self).__init__(name=name, **kwargs)
- self._feature_columns = _clean_feature_columns(feature_columns)
+ self._feature_columns = _normalize_feature_columns(
+ feature_columns)
self._weight_collections = list(weight_collections or [])
if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections:
self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
@@ -643,7 +688,7 @@ def _transform_features(features, feature_columns):
Returns:
A `dict` mapping `_FeatureColumn` to `Tensor` and `SparseTensor` values.
"""
- feature_columns = _clean_feature_columns(feature_columns)
+ feature_columns = _normalize_feature_columns(feature_columns)
outputs = {}
with ops.name_scope(
None, default_name='transform_features', values=features.values()):
@@ -911,7 +956,8 @@ def shared_embedding_columns(
tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
which to restore the column weights. Required if `ckpt_to_load_from` is
not `None`.
- max_norm: If not `None`, embedding values are l2-normalized to this value.
+ max_norm: If not `None`, each embedding is clipped if its l2-norm is
+ larger than this value, before combining.
trainable: Whether or not the embedding is trainable. Default is True.
Returns:
@@ -1182,12 +1228,13 @@ def categorical_column_with_hash_bucket(key,
Use this when your sparse features are in string or integer format, and you
want to distribute your inputs into a finite number of buckets by hashing.
- output_id = Hash(input_feature_string) % bucket_size
+ output_id = Hash(input_feature_string) % bucket_size for string-type input.
+ For int-type input, the value is first converted to its string representation
+ and then hashed by the same formula.
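A hedged sketch of the hashing behavior (the key name and bucket count are illustrative):

```python
import tensorflow as tf

keywords = tf.feature_column.categorical_column_with_hash_bucket(
    'keywords', hash_bucket_size=10000)
# A string input maps to Hash("tensorflow") % 10000; an int input would
# be hashed via its string form, e.g. Hash("7") % 10000.
features = {'keywords': tf.constant([['tensorflow'], ['']])}
# The empty string marks a missing value and is dropped, as noted below.
predictions = tf.feature_column.linear_model(features, [keywords])
```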
For input dictionary `features`, `features[key]` is either `Tensor` or
`SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
- and `''` for string. Note that these values are independent of the
- `default_value` argument.
+ and `''` for string, which will be dropped by this feature column.
Example:
@@ -1249,8 +1296,7 @@ def categorical_column_with_vocabulary_file(key,
For input dictionary `features`, `features[key]` is either `Tensor` or
`SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
- and `''` for string. Note that these values are independent of the
- `default_value` argument.
+ and `''` for string, which will be dropped by this feature column.
Example with `num_oov_buckets`:
File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state
@@ -1366,8 +1412,7 @@ def categorical_column_with_vocabulary_list(
For input dictionary `features`, `features[key]` is either `Tensor` or
`SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
- and `''` for string. Note that these values are independent of the
- `default_value` argument.
+ and `''` for string, which will be dropped by this feature column.
Example with `num_oov_buckets`:
In the following example, each input in `vocabulary_list` is assigned an ID
@@ -1480,8 +1525,7 @@ def categorical_column_with_identity(key, num_buckets, default_value=None):
For input dictionary `features`, `features[key]` is either `Tensor` or
`SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
- and `''` for string. Note that these values are independent of the
- `default_value` argument.
+ and `''` for string, which will be dropped by this feature column.
In the following examples, each input in the range `[0, 1000000)` is assigned
the same value. All other inputs are assigned `default_value` 0. Note that a
@@ -1538,8 +1582,14 @@ def categorical_column_with_identity(key, num_buckets, default_value=None):
def indicator_column(categorical_column):
"""Represents multi-hot representation of given categorical column.
- Used to wrap any `categorical_column_*` (e.g., to feed to DNN). Use
- `embedding_column` if the inputs are sparse.
+ - For a DNN model, `indicator_column` can be used to wrap any
+ `categorical_column_*` (e.g., to feed a DNN). Consider using
+ `embedding_column` if the number of buckets/unique values is large.
+
+ - For a wide (aka linear) model, `indicator_column` is the internal
+ representation for a categorical column when it is passed directly (as an
+ element of `feature_columns`) to `linear_model`. See `linear_model` for
+ details.
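A hedged sketch of both uses (the `colors` column is illustrative):

```python
import tensorflow as tf

colors = tf.feature_column.categorical_column_with_vocabulary_list(
    'colors', ['red', 'green', 'blue'])
features = {'colors': tf.constant([['red'], ['blue']])}

# DNN use: explicit multi-hot wrapping for input_layer.
dense = tf.feature_column.input_layer(
    features, [tf.feature_column.indicator_column(colors)])

# Wide use: the raw categorical column may be passed directly;
# linear_model applies the indicator representation internally.
predictions = tf.feature_column.linear_model(features, [colors])
```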
```python
name = indicator_column(categorical_column_with_vocabulary_list(
@@ -1956,7 +2006,7 @@ def _create_weighted_sum(column,
weight_collections,
trainable,
weight_var=None):
- """Creates a weighted sum for a dense or sparse column for linear_model."""
+ """Creates a weighted sum for a dense/categorical column for linear_model."""
if isinstance(column, _CategoricalColumn):
return _create_categorical_column_weighted_sum(
column=column,
@@ -2055,7 +2105,34 @@ def _create_categorical_column_weighted_sum(column,
weight_collections,
trainable,
weight_var=None):
- """Create a weighted sum of a categorical column for linear_model."""
+ # pylint: disable=g-doc-return-or-yield,g-doc-args
+ """Create a weighted sum of a categorical column for linear_model.
+
+ Note to maintainer: As an implementation detail, the weighted sum is
+ implemented via embedding_lookup_sparse for efficiency. Mathematically,
+ the two are the same.
+
+ To be specific, a categorical column can conceptually be treated as a
+ multi-hot vector. Say:
+
+ ```python
+ x = [0 0 1] # categorical column input
+ w = [a b c] # weights
+ ```
+ The weighted sum is `c` in this case, which is the same as `w[2]`.
+
+ Another example is
+
+ ```python
+ x = [0 1 1] # categorical column input
+ w = [a b c] # weights
+ ```
+ The weighted sum is `b + c` in this case, which is the same as `w[1] + w[2]`.
+
+ In both cases, the weighted sum can be implemented via
+ embedding_lookup_sparse with sparse_combiner = "sum".
+ """
+
sparse_tensors = column._get_sparse_tensors( # pylint: disable=protected-access
builder,
weight_collections=weight_collections,
@@ -2249,7 +2326,7 @@ def _shape_offsets(shape):
# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
-def _to_sparse_input(input_tensor, ignore_value=None):
+def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
"""Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
If `input_tensor` is already a `SparseTensor`, just return it.
@@ -2293,8 +2370,22 @@ def _to_sparse_input(input_tensor, ignore_value=None):
input_tensor, out_type=dtypes.int64, name='dense_shape'))
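A hedged sketch of the conversion this function performs (the dense input is illustrative):

```python
import tensorflow as tf

dense = tf.constant([['a', ''], ['b', 'c']])
# Cells equal to the ignore value ('' for strings) are dropped.
mask = tf.not_equal(dense, '')
indices = tf.where(mask)
sparse = tf.SparseTensor(
    indices=indices,
    values=tf.gather_nd(dense, indices),
    dense_shape=tf.shape(dense, out_type=tf.int64))
```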
-def _clean_feature_columns(feature_columns):
- """Verifies and normalizes `feature_columns` input."""
+def _normalize_feature_columns(feature_columns):
+ """Normalizes the `feature_columns` input.
+
+ This method converts `feature_columns` to a list as best it can. In
+ addition, it verifies the type and other properties of `feature_columns`, as
+ required by the downstream library.
+
+ Args:
+ feature_columns: The raw feature columns, usually passed by users.
+
+ Returns:
+ The normalized feature column list.
+
+ Raises:
+ ValueError: for any invalid inputs, such as empty, duplicated names, etc.
+ """
if isinstance(feature_columns, _FeatureColumn):
feature_columns = [feature_columns]
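Because a single `_FeatureColumn` is wrapped into a list (as shown above), the public entry points accept either form; a hedged sketch:

```python
import tensorflow as tf

price = tf.feature_column.numeric_column('price')
features = {'price': tf.constant([[1.0], [5.0]])}
# Both calls are equivalent: the single column is normalized to [price].
a = tf.feature_column.input_layer(features, price)
b = tf.feature_column.input_layer(features, [price])
```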
@@ -2420,6 +2511,7 @@ class _BucketizedColumn(_DenseColumn, _CategoricalColumn,
def _get_sparse_tensors(self, inputs, weight_collections=None,
trainable=None):
+ """Converts dense inputs to SparseTensor so downstream code can use it."""
input_tensor = inputs.get(self)
batch_size = array_ops.shape(input_tensor)[0]
# By construction, source_column is always one-dimensional.
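A hedged sketch of a bucketized column consumed as a categorical column (the boundaries are illustrative):

```python
import tensorflow as tf

price = tf.feature_column.numeric_column('price')
# Three buckets: (-inf, 10), [10, 20), [20, +inf).
price_buckets = tf.feature_column.bucketized_column(
    price, boundaries=[10.0, 20.0])
features = {'price': tf.constant([[5.0], [25.0]])}
# linear_model uses the sparse (categorical) representation built by
# _get_sparse_tensors, assigning one weight per bucket.
predictions = tf.feature_column.linear_model(features, [price_buckets])
```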
@@ -2804,7 +2896,7 @@ class _HashedCategoricalColumn(
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
- input_tensor = _to_sparse_input(inputs.get(self.key))
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
raise ValueError('SparseColumn input must be a SparseTensor.')
@@ -2855,7 +2947,7 @@ class _VocabularyFileCategoricalColumn(
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
- input_tensor = _to_sparse_input(inputs.get(self.key))
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
if self.dtype.is_integer != input_tensor.dtype.is_integer:
raise ValueError(
@@ -2907,7 +2999,7 @@ class _VocabularyListCategoricalColumn(
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
- input_tensor = _to_sparse_input(inputs.get(self.key))
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
if self.dtype.is_integer != input_tensor.dtype.is_integer:
raise ValueError(
@@ -2959,7 +3051,7 @@ class _IdentityCategoricalColumn(
return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
def _transform_feature(self, inputs):
- input_tensor = _to_sparse_input(inputs.get(self.key))
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
if not input_tensor.dtype.is_integer:
raise ValueError(
@@ -3041,7 +3133,8 @@ class _WeightedCategoricalColumn(
self.dtype, weight_tensor.dtype))
if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor):
# The weight tensor can be a regular Tensor. In this case, sparsify it.
- weight_tensor = _to_sparse_input(weight_tensor, ignore_value=0.0)
+ weight_tensor = _to_sparse_input_and_drop_ignore_values(
+ weight_tensor, ignore_value=0.0)
if not weight_tensor.dtype.is_floating:
weight_tensor = math_ops.to_float(weight_tensor)
return (inputs.get(self.categorical_column), weight_tensor)
@@ -3486,3 +3579,8 @@ class _SequenceCategoricalColumn(
weight_tensor,
shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
return _CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
+
+
+# TODO(xiejw): Remove the following alias once call sites are updated.
+_clean_feature_columns = _normalize_feature_columns
+_to_sparse_input = _to_sparse_input_and_drop_ignore_values