aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column/feature_column_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/feature_column/feature_column_v2.py')
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py869
1 files changed, 705 insertions, 164 deletions
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index b79373c475..6d089de991 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -136,6 +136,7 @@ import six
from tensorflow.python.eager import context
+from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -157,9 +158,16 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
+_FEATURE_COLUMN_DEPRECATION_DATE = '2018-11-30'
+_FEATURE_COLUMN_DEPRECATION = ('The old _FeatureColumn APIs are being '
+ 'deprecated. Please use the new FeatureColumn '
+ 'APIs instead.')
+
+
class StateManager(object):
"""Manages the state associated with FeatureColumns.
@@ -440,10 +448,6 @@ class FeatureLayer(Layer):
return (input_shape[0], total_elements)
-def _strip_leading_slashes(name):
- return name.rsplit('/', 1)[-1]
-
-
class LinearModel(Layer):
"""Produces a linear prediction `Tensor` based on given `feature_columns`.
@@ -775,12 +779,12 @@ def embedding_column(
categorical_column, dimension, combiner='mean', initializer=None,
ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None,
trainable=True):
- """`_DenseColumn` that converts from sparse, categorical input.
+ """`DenseColumn` that converts from sparse, categorical input.
Use this when your inputs are sparse, but you want to convert them to a dense
representation (e.g., to feed to a DNN).
- Inputs must be a `_CategoricalColumn` created by any of the
+ Inputs must be a `CategoricalColumn` created by any of the
`categorical_column_*` function. Here is an example of using
`embedding_column` with `DNNClassifier`:
@@ -814,12 +818,12 @@ def embedding_column(
```
Args:
- categorical_column: A `_CategoricalColumn` created by a
+ categorical_column: A `CategoricalColumn` created by a
`categorical_column_with_*` function. This column produces the sparse IDs
that are inputs to the embedding lookup.
dimension: An integer specifying dimension of the embedding, must be > 0.
- combiner: A string specifying how to reduce if there are multiple entries
- in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
+ combiner: A string specifying how to reduce if there are multiple entries in
+ a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
'mean' the default. 'sqrtn' often achieves good accuracy, in particular
with bag-of-words columns. Each of this can be thought as example level
normalizations on the column. For more information, see
@@ -830,14 +834,14 @@ def embedding_column(
`1/sqrt(dimension)`.
ckpt_to_load_from: String representing checkpoint name/pattern from which to
restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
- tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
- which to restore the column weights. Required if `ckpt_to_load_from` is
- not `None`.
+ tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
+ to restore the column weights. Required if `ckpt_to_load_from` is not
+ `None`.
max_norm: If not `None`, embedding values are l2-normalized to this value.
trainable: Whether or not the embedding is trainable. Default is True.
Returns:
- `_DenseColumn` that converts from sparse input.
+ `DenseColumn` that converts from sparse input.
Raises:
ValueError: if `dimension` not > 0.
@@ -1181,7 +1185,7 @@ def bucketized_column(source_column, boundaries):
one-dimensional.
ValueError: If `boundaries` is not a sorted list or tuple.
"""
- if not isinstance(source_column, NumericColumn):
+ if not isinstance(source_column, (NumericColumn, fc_old._NumericColumn)): # pylint: disable=protected-access
raise ValueError(
'source_column must be a column generated with numeric_column(). '
'Given: {}'.format(source_column))
@@ -1390,7 +1394,7 @@ def categorical_column_with_vocabulary_file(key,
def categorical_column_with_vocabulary_list(
key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0):
- """A `_CategoricalColumn` with in-memory vocabulary.
+ """A `CategoricalColumn` with in-memory vocabulary.
Use this when your inputs are in string or integer format, and you have an
in-memory vocabulary mapping each value to an integer ID. By default,
@@ -1439,14 +1443,14 @@ def categorical_column_with_vocabulary_list(
```
Args:
- key: A unique string identifying the input feature. It is used as the
- column name and the dictionary key for feature parsing configs, feature
- `Tensor` objects, and feature columns.
+ key: A unique string identifying the input feature. It is used as the column
+ name and the dictionary key for feature parsing configs, feature `Tensor`
+ objects, and feature columns.
vocabulary_list: An ordered iterable defining the vocabulary. Each feature
is mapped to the index of its value (if present) in `vocabulary_list`.
Must be castable to `dtype`.
- dtype: The type of features. Only string and integer types are supported.
- If `None`, it will be inferred from `vocabulary_list`.
+ dtype: The type of features. Only string and integer types are supported. If
+ `None`, it will be inferred from `vocabulary_list`.
default_value: The integer ID value to return for out-of-vocabulary feature
values, defaults to `-1`. This can not be specified with a positive
`num_oov_buckets`.
@@ -1604,7 +1608,7 @@ def indicator_column(categorical_column):
def weighted_categorical_column(
categorical_column, weight_feature_key, dtype=dtypes.float32):
- """Applies weight values to a `_CategoricalColumn`.
+ """Applies weight values to a `CategoricalColumn`.
Use this when each of your sparse inputs has both an ID and a value. For
example, if you're representing text documents as a collection of word
@@ -1655,7 +1659,7 @@ def weighted_categorical_column(
the same indices and dense shape.
Args:
- categorical_column: A `_CategoricalColumn` created by
+ categorical_column: A `CategoricalColumn` created by
`categorical_column_with_*` functions.
weight_feature_key: String key for weight values.
dtype: Type of weights, such as `tf.float32`. Only float and integer weights
@@ -1788,12 +1792,13 @@ def crossed_column(keys, hash_bucket_size, hash_key=None):
'keys must be a list with length > 1. Given: {}'.format(keys))
for key in keys:
if (not isinstance(key, six.string_types) and
- not isinstance(key, CategoricalColumn)):
+ not isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn))): # pylint: disable=protected-access
raise ValueError(
'Unsupported key type. All keys must be either string, or '
'categorical column except HashedCategoricalColumn. '
'Given: {}'.format(key))
- if isinstance(key, HashedCategoricalColumn):
+ if isinstance(key,
+ (HashedCategoricalColumn, fc_old._HashedCategoricalColumn)): # pylint: disable=protected-access
raise ValueError(
'categorical_column_with_hash_bucket is not supported for crossing. '
'Hashing before crossing will increase probability of collision. '
@@ -1882,6 +1887,16 @@ class FeatureColumn(object):
"""
pass
+ @abc.abstractproperty
+ def _is_v2_column(self):
+ """Returns whether this FeatureColumn is fully conformant to the new API.
+
+ This is needed for composition type cases where an EmbeddingColumn etc.
+ might take in old categorical columns as input and then we want to use the
+ old API.
+ """
+ pass
+
class DenseColumn(FeatureColumn):
"""Represents a column which can be represented as `Tensor`.
@@ -1927,6 +1942,8 @@ def is_feature_column_v2(feature_columns):
for feature_column in feature_columns:
if not isinstance(feature_column, FeatureColumn):
return False
+ if not feature_column._is_v2_column: # pylint: disable=protected-access
+ return False
return True
@@ -2202,19 +2219,6 @@ class FeatureTransformationCache(object):
# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
-def _shape_offsets(shape):
- """Returns moving offset for each dimension given shape."""
- offsets = []
- for dim in reversed(shape):
- if offsets:
- offsets.append(dim * offsets[-1])
- else:
- offsets.append(dim)
- offsets.reverse()
- return offsets
-
-
-# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
"""Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
@@ -2306,12 +2310,17 @@ def _normalize_feature_columns(feature_columns):
class NumericColumn(
DenseColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
collections.namedtuple(
'NumericColumn',
('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
"""see `numeric_column`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -2325,6 +2334,27 @@ class NumericColumn(
self.default_value)
}
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+
+ def _transform_input_tensor(self, input_tensor):
+ if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+ raise ValueError(
+ 'The corresponding Tensor of numerical column must be a Tensor. '
+ 'SparseTensor is not supported. key: {}'.format(self.key))
+ if self.normalizer_fn is not None:
+ input_tensor = self.normalizer_fn(input_tensor)
+ return math_ops.to_float(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ input_tensor = inputs.get(self.key)
+ return self._transform_input_tensor(input_tensor)
+
def transform_feature(self, transformation_cache, state_manager):
"""See `FeatureColumn` base class.
@@ -2342,19 +2372,19 @@ class NumericColumn(
ValueError: If a SparseTensor is passed in.
"""
input_tensor = transformation_cache.get(self.key, state_manager)
- if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
- raise ValueError(
- 'The corresponding Tensor of numerical column must be a Tensor. '
- 'SparseTensor is not supported. key: {}'.format(self.key))
- if self.normalizer_fn is not None:
- input_tensor = self.normalizer_fn(input_tensor)
- return math_ops.to_float(input_tensor)
+ return self._transform_input_tensor(input_tensor)
@property
def variable_shape(self):
"""See `DenseColumn` base class."""
return tensor_shape.TensorShape(self.shape)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _variable_shape(self):
+ return self.variable_shape
+
def get_dense_tensor(self, transformation_cache, state_manager):
"""Returns dense `Tensor` representing numeric feature.
@@ -2371,13 +2401,29 @@ class NumericColumn(
# representation created by _transform_feature.
return transformation_cache.get(self, state_manager)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ del weight_collections
+ del trainable
+ return inputs.get(self)
+
-class BucketizedColumn(DenseColumn, CategoricalColumn,
- collections.namedtuple('BucketizedColumn',
- ('source_column', 'boundaries'))):
+class BucketizedColumn(
+ DenseColumn,
+ CategoricalColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
+ collections.namedtuple('BucketizedColumn',
+ ('source_column', 'boundaries'))):
"""See `bucketized_column`."""
@property
+ def _is_v2_column(self):
+ return (isinstance(self.source_column, FeatureColumn) and
+ self.source_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return '{}_bucketized'.format(self.source_column.name)
@@ -2387,6 +2433,21 @@ class BucketizedColumn(DenseColumn, CategoricalColumn,
"""See `FeatureColumn` base class."""
return self.source_column.parse_example_spec
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.source_column._parse_example_spec # pylint: disable=protected-access
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ """Returns bucketized categorical `source_column` tensor."""
+ source_tensor = inputs.get(self.source_column)
+ return math_ops._bucketize( # pylint: disable=protected-access
+ source_tensor,
+ boundaries=self.boundaries)
+
def transform_feature(self, transformation_cache, state_manager):
"""Returns bucketized categorical `source_column` tensor."""
source_tensor = transformation_cache.get(self.source_column, state_manager)
@@ -2400,24 +2461,45 @@ class BucketizedColumn(DenseColumn, CategoricalColumn,
return tensor_shape.TensorShape(
tuple(self.source_column.shape) + (len(self.boundaries) + 1,))
- def get_dense_tensor(self, transformation_cache, state_manager):
- """Returns one hot encoded dense `Tensor`."""
- input_tensor = transformation_cache.get(self, state_manager)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _variable_shape(self):
+ return self.variable_shape
+
+ def _get_dense_tensor_for_input_tensor(self, input_tensor):
return array_ops.one_hot(
indices=math_ops.to_int64(input_tensor),
depth=len(self.boundaries) + 1,
on_value=1.,
off_value=0.)
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns one hot encoded dense `Tensor`."""
+ input_tensor = transformation_cache.get(self, state_manager)
+ return self._get_dense_tensor_for_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ del weight_collections
+ del trainable
+ input_tensor = inputs.get(self)
+ return self._get_dense_tensor_for_input_tensor(input_tensor)
+
@property
def num_buckets(self):
"""See `CategoricalColumn` base class."""
# By construction, source_column is always one-dimensional.
return (len(self.boundaries) + 1) * self.source_column.shape[0]
- def get_sparse_tensors(self, transformation_cache, state_manager):
- """Converts dense inputs to SparseTensor so downstream code can use it."""
- input_tensor = transformation_cache.get(self, state_manager)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
+ def _get_sparse_tensors_for_input_tensor(self, input_tensor):
batch_size = array_ops.shape(input_tensor)[0]
# By construction, source_column is always one-dimensional.
source_dimension = self.source_column.shape[0]
@@ -2443,9 +2525,27 @@ class BucketizedColumn(DenseColumn, CategoricalColumn,
dense_shape=dense_shape)
return CategoricalColumn.IdWeightPair(sparse_tensor, None)
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """Converts dense inputs to SparseTensor so downstream code can use it."""
+ input_tensor = transformation_cache.get(self, state_manager)
+ return self._get_sparse_tensors_for_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ """Converts dense inputs to SparseTensor so downstream code can use it."""
+ del weight_collections
+ del trainable
+ input_tensor = inputs.get(self)
+ return self._get_sparse_tensors_for_input_tensor(input_tensor)
+
class EmbeddingColumn(
- DenseColumn, SequenceDenseColumn,
+ DenseColumn,
+ SequenceDenseColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._SequenceDenseColumn, # pylint: disable=protected-access
collections.namedtuple(
'EmbeddingColumn',
('categorical_column', 'dimension', 'combiner', 'initializer',
@@ -2453,6 +2553,11 @@ class EmbeddingColumn(
"""See `embedding_column`."""
@property
+ def _is_v2_column(self):
+ return (isinstance(self.categorical_column, FeatureColumn) and
+ self.categorical_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return '{}_embedding'.format(self.categorical_column.name)
@@ -2462,18 +2567,35 @@ class EmbeddingColumn(
"""See `FeatureColumn` base class."""
return self.categorical_column.parse_example_spec
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.categorical_column._parse_example_spec # pylint: disable=protected-access
+
def transform_feature(self, transformation_cache, state_manager):
"""Transforms underlying `categorical_column`."""
return transformation_cache.get(self.categorical_column, state_manager)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ return inputs.get(self.categorical_column)
+
@property
def variable_shape(self):
"""See `DenseColumn` base class."""
return tensor_shape.vector(self.dimension)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _variable_shape(self):
+ return self.variable_shape
+
def create_state(self, state_manager):
"""Creates the embedding lookup variable."""
- embedding_shape = (self.categorical_column.num_buckets, self.dimension)
+ embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
state_manager.create_variable(
self,
name='embedding_weights',
@@ -2482,17 +2604,11 @@ class EmbeddingColumn(
trainable=self.trainable,
initializer=self.initializer)
- def _get_dense_tensor_internal(self, transformation_cache, state_manager):
- """Private method that follows the signature of _get_dense_tensor."""
- # Get sparse IDs and weights.
- sparse_tensors = self.categorical_column.get_sparse_tensors(
- transformation_cache, state_manager)
+ def _get_dense_tensor_internal_helper(self, sparse_tensors,
+ embedding_weights):
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- embedding_weights = state_manager.get_variable(
- self, name='embedding_weights')
-
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
if isinstance(to_restore, variables.PartitionedVariable):
@@ -2510,6 +2626,30 @@ class EmbeddingColumn(
name='%s_weights' % self.name,
max_norm=self.max_norm)
+ def _get_dense_tensor_internal(self, sparse_tensors, state_manager):
+ """Private method that follows the signature of get_dense_tensor."""
+ embedding_weights = state_manager.get_variable(
+ self, name='embedding_weights')
+ return self._get_dense_tensor_internal_helper(sparse_tensors,
+ embedding_weights)
+
+ def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
+ trainable):
+ """Private method that follows the signature of _get_dense_tensor."""
+ embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
+ if (weight_collections and
+ ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
+ weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
+ embedding_weights = variable_scope.get_variable(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer,
+ trainable=self.trainable and trainable,
+ collections=weight_collections)
+ return self._get_dense_tensor_internal_helper(sparse_tensors,
+ embedding_weights)
+
def get_dense_tensor(self, transformation_cache, state_manager):
"""Returns tensor after doing the embedding lookup.
@@ -2535,7 +2675,30 @@ class EmbeddingColumn(
'sequence_input_layer instead of input_layer. '
'Given (type {}): {}'.format(self.name, type(self.categorical_column),
self.categorical_column))
- return self._get_dense_tensor_internal(transformation_cache, state_manager)
+ # Get sparse IDs and weights.
+ sparse_tensors = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ return self._get_dense_tensor_internal(sparse_tensors, state_manager)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ if isinstance(
+ self.categorical_column,
+ (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must not be of type _SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ sparse_tensors = self.categorical_column._get_sparse_tensors( # pylint: disable=protected-access
+ inputs, weight_collections, trainable)
+ return self._old_get_dense_tensor_internal(sparse_tensors,
+ weight_collections, trainable)
def get_sequence_dense_tensor(self, transformation_cache, state_manager):
"""See `SequenceDenseColumn` base class."""
@@ -2547,21 +2710,40 @@ class EmbeddingColumn(
'Suggested fix: Use one of sequence_categorical_column_with_*. '
'Given (type {}): {}'.format(self.name, type(self.categorical_column),
self.categorical_column))
- dense_tensor = self._get_dense_tensor_internal( # pylint: disable=protected-access
+ sparse_tensors = self.categorical_column.get_sequence_sparse_tensors(
transformation_cache, state_manager)
- sparse_tensors = self.categorical_column.get_sparse_tensors(
- transformation_cache, state_manager)
- sequence_length = _sequence_length_from_sparse_tensor(
+ dense_tensor = self._get_dense_tensor_internal(sparse_tensors,
+ state_manager)
+ sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
sparse_tensors.id_tensor)
return SequenceDenseColumn.TensorSequenceLengthPair(
dense_tensor=dense_tensor, sequence_length=sequence_length)
-
-def _get_graph_for_variable(var):
- if isinstance(var, variables.PartitionedVariable):
- return list(var)[0].graph
- else:
- return var.graph
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sequence_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ if not isinstance(
+ self.categorical_column,
+ (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must be of type _SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ dense_tensor = self._old_get_dense_tensor_internal(
+ sparse_tensors,
+ weight_collections=weight_collections,
+ trainable=trainable)
+ sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
+ sparse_tensors.id_tensor)
+ return SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
class SharedEmbeddingStateManager(Layer):
@@ -2633,8 +2815,17 @@ def maybe_create_shared_state_manager(feature_columns):
return None
+def _raise_shared_embedding_column_error():
+ raise ValueError('SharedEmbeddingColumns are not supported in '
+ '`linear_model` or `input_layer`. Please use '
+ '`FeatureLayer` or `LinearModel` instead.')
+
+
class SharedEmbeddingColumn(
- DenseColumn, SequenceDenseColumn,
+ DenseColumn,
+ SequenceDenseColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._SequenceDenseColumn, # pylint: disable=protected-access
collections.namedtuple(
'SharedEmbeddingColumn',
('categorical_column', 'dimension', 'combiner', 'initializer',
@@ -2643,6 +2834,10 @@ class SharedEmbeddingColumn(
"""See `embedding_column`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return '{}_shared_embedding'.format(self.categorical_column.name)
@@ -2662,15 +2857,26 @@ class SharedEmbeddingColumn(
"""See `FeatureColumn` base class."""
return self.categorical_column.parse_example_spec
+ @property
+ def _parse_example_spec(self):
+ return _raise_shared_embedding_column_error()
+
def transform_feature(self, transformation_cache, state_manager):
"""See `FeatureColumn` base class."""
return transformation_cache.get(self.categorical_column, state_manager)
+ def _transform_feature(self, inputs):
+ return _raise_shared_embedding_column_error()
+
@property
def variable_shape(self):
"""See `DenseColumn` base class."""
return tensor_shape.vector(self.dimension)
+ @property
+ def _variable_shape(self):
+ return _raise_shared_embedding_column_error()
+
def create_state(self, state_manager):
"""Creates the shared embedding lookup variable."""
if not isinstance(state_manager, SharedEmbeddingStateManager):
@@ -2731,6 +2937,9 @@ class SharedEmbeddingColumn(
self.categorical_column))
return self._get_dense_tensor_internal(transformation_cache, state_manager)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ return _raise_shared_embedding_column_error()
+
def get_sequence_dense_tensor(self, transformation_cache, state_manager):
"""See `SequenceDenseColumn` base class."""
if not isinstance(self.categorical_column, SequenceCategoricalColumn):
@@ -2745,11 +2954,17 @@ class SharedEmbeddingColumn(
state_manager)
sparse_tensors = self.categorical_column.get_sparse_tensors(
transformation_cache, state_manager)
- sequence_length = _sequence_length_from_sparse_tensor(
+ sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
sparse_tensors.id_tensor)
return SequenceDenseColumn.TensorSequenceLengthPair(
dense_tensor=dense_tensor, sequence_length=sequence_length)
+ def _get_sequence_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ return _raise_shared_embedding_column_error()
+
def _create_tuple(shape, value):
"""Returns a tuple with given shape and filled with value."""
@@ -2858,11 +3073,16 @@ def _check_default_value(shape, default_value, dtype, key):
class HashedCategoricalColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple('HashedCategoricalColumn',
('key', 'hash_bucket_size', 'dtype'))):
"""see `categorical_column_with_hash_bucket`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -2872,10 +3092,14 @@ class HashedCategoricalColumn(
"""See `FeatureColumn` base class."""
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
- def transform_feature(self, transformation_cache, state_manager):
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+
+ def _transform_input_tensor(self, input_tensor):
"""Hashes the values in the feature_column."""
- input_tensor = _to_sparse_input_and_drop_ignore_values(
- transformation_cache.get(self.key, state_manager))
if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
raise ValueError('SparseColumn input must be a SparseTensor.')
@@ -2899,25 +3123,56 @@ class HashedCategoricalColumn(
return sparse_tensor_lib.SparseTensor(
input_tensor.indices, sparse_id_values, input_tensor.dense_shape)
+ def transform_feature(self, transformation_cache, state_manager):
+ """Hashes the values in the feature_column."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+ return self._transform_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
+ return self._transform_input_tensor(input_tensor)
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return self.hash_bucket_size
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
return CategoricalColumn.IdWeightPair(
transformation_cache.get(self, state_manager), None)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ return CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
class VocabularyFileCategoricalColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple('VocabularyFileCategoricalColumn',
('key', 'vocabulary_file', 'vocabulary_size',
'num_oov_buckets', 'dtype', 'default_value'))):
"""See `categorical_column_with_vocabulary_file`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -2927,11 +3182,14 @@ class VocabularyFileCategoricalColumn(
"""See `FeatureColumn` base class."""
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
- def transform_feature(self, transformation_cache, state_manager):
- """Creates a lookup table for the vocabulary."""
- input_tensor = _to_sparse_input_and_drop_ignore_values(
- transformation_cache.get(self.key, state_manager))
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+ def _transform_input_tensor(self, input_tensor):
+ """Creates a lookup table for the vocabulary."""
if self.dtype.is_integer != input_tensor.dtype.is_integer:
raise ValueError(
'Column dtype and SparseTensors dtype must be compatible. '
@@ -2957,19 +3215,46 @@ class VocabularyFileCategoricalColumn(
key_dtype=key_dtype,
name='{}_lookup'.format(self.key)).lookup(input_tensor)
+ def transform_feature(self, transformation_cache, state_manager):
+ """Creates a lookup table for the vocabulary."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+ return self._transform_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
+ return self._transform_input_tensor(input_tensor)
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return self.vocabulary_size + self.num_oov_buckets
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
return CategoricalColumn.IdWeightPair(
transformation_cache.get(self, state_manager), None)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ return CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
class VocabularyListCategoricalColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple(
'VocabularyListCategoricalColumn',
('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets'))
@@ -2977,6 +3262,10 @@ class VocabularyListCategoricalColumn(
"""See `categorical_column_with_vocabulary_list`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -2986,11 +3275,14 @@ class VocabularyListCategoricalColumn(
"""See `FeatureColumn` base class."""
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
- def transform_feature(self, transformation_cache, state_manager):
- """Creates a lookup table for the vocabulary list."""
- input_tensor = _to_sparse_input_and_drop_ignore_values(
- transformation_cache.get(self.key, state_manager))
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+ def _transform_input_tensor(self, input_tensor):
+ """Creates a lookup table for the vocabulary list."""
if self.dtype.is_integer != input_tensor.dtype.is_integer:
raise ValueError(
'Column dtype and SparseTensors dtype must be compatible. '
@@ -3015,25 +3307,56 @@ class VocabularyListCategoricalColumn(
dtype=key_dtype,
name='{}_lookup'.format(self.key)).lookup(input_tensor)
+ def transform_feature(self, transformation_cache, state_manager):
+ """Creates a lookup table for the vocabulary list."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+ return self._transform_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
+ return self._transform_input_tensor(input_tensor)
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return len(self.vocabulary_list) + self.num_oov_buckets
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
return CategoricalColumn.IdWeightPair(
transformation_cache.get(self, state_manager), None)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ return CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
class IdentityCategoricalColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple('IdentityCategoricalColumn',
('key', 'number_buckets', 'default_value'))):
"""See `categorical_column_with_identity`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -3043,11 +3366,14 @@ class IdentityCategoricalColumn(
"""See `FeatureColumn` base class."""
return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
- def transform_feature(self, transformation_cache, state_manager):
- """Returns a SparseTensor with identity values."""
- input_tensor = _to_sparse_input_and_drop_ignore_values(
- transformation_cache.get(self.key, state_manager))
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+ def _transform_input_tensor(self, input_tensor):
+ """Returns a SparseTensor with identity values."""
if not input_tensor.dtype.is_integer:
raise ValueError(
'Invalid input, not integer. key: {} dtype: {}'.format(
@@ -3082,25 +3408,57 @@ class IdentityCategoricalColumn(
values=values,
dense_shape=input_tensor.dense_shape)
+ def transform_feature(self, transformation_cache, state_manager):
+ """Returns a SparseTensor with identity values."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+ return self._transform_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
+ return self._transform_input_tensor(input_tensor)
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return self.number_buckets
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
return CategoricalColumn.IdWeightPair(
transformation_cache.get(self, state_manager), None)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ return CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
class WeightedCategoricalColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple(
'WeightedCategoricalColumn',
('categorical_column', 'weight_feature_key', 'dtype'))):
"""See `weighted_categorical_column`."""
@property
+ def _is_v2_column(self):
+ return (isinstance(self.categorical_column, FeatureColumn) and
+ self.categorical_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return '{}_weighted_by_{}'.format(
@@ -3117,14 +3475,28 @@ class WeightedCategoricalColumn(
return config
@property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ config = self.categorical_column._parse_example_spec # pylint: disable=protected-access
+ if self.weight_feature_key in config:
+ raise ValueError('Parse config {} already exists for {}.'.format(
+ config[self.weight_feature_key], self.weight_feature_key))
+ config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
+ return config
+
+ @property
def num_buckets(self):
"""See `DenseColumn` base class."""
return self.categorical_column.num_buckets
- def transform_feature(self, transformation_cache, state_manager):
- """Applies weights to tensor generated from `categorical_column`'."""
- weight_tensor = transformation_cache.get(self.weight_feature_key,
- state_manager)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.categorical_column._num_buckets # pylint: disable=protected-access
+
+ def _transform_weight_tensor(self, weight_tensor):
if weight_tensor is None:
raise ValueError('Missing weights {}.'.format(self.weight_feature_key))
weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
@@ -3138,27 +3510,63 @@ class WeightedCategoricalColumn(
weight_tensor, ignore_value=0.0)
if not weight_tensor.dtype.is_floating:
weight_tensor = math_ops.to_float(weight_tensor)
+ return weight_tensor
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Applies weights to tensor generated from `categorical_column`."""
+ weight_tensor = transformation_cache.get(self.weight_feature_key,
+ state_manager)
+ weight_tensor = self._transform_weight_tensor(weight_tensor)
return (transformation_cache.get(self.categorical_column, state_manager),
weight_tensor)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ """Applies weights to tensor generated from `categorical_column`."""
+ weight_tensor = inputs.get(self.weight_feature_key)
+ weight_tensor = self._transform_weight_tensor(weight_tensor)
+ return (inputs.get(self.categorical_column), weight_tensor)
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
tensors = transformation_cache.get(self, state_manager)
return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ tensors = inputs.get(self)
+ return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
+
class CrossedColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple('CrossedColumn',
('keys', 'hash_bucket_size', 'hash_key'))):
"""See `crossed_column`."""
@property
+ def _is_v2_column(self):
+ for key in _collect_leaf_level_keys(self):
+ if isinstance(key, six.string_types):
+ continue
+ if not isinstance(key, FeatureColumn):
+ return False
+ if not key._is_v2_column: # pylint: disable=protected-access
+ return False
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
feature_names = []
for key in _collect_leaf_level_keys(self):
- if isinstance(key, FeatureColumn):
+ if isinstance(key, (FeatureColumn, fc_old._FeatureColumn)): # pylint: disable=protected-access
feature_names.append(key.name)
else: # key must be a string
feature_names.append(key)
@@ -3171,17 +3579,25 @@ class CrossedColumn(
for key in self.keys:
if isinstance(key, FeatureColumn):
config.update(key.parse_example_spec)
+ elif isinstance(key, fc_old._FeatureColumn): # pylint: disable=protected-access
+ config.update(key._parse_example_spec) # pylint: disable=protected-access
else: # key must be a string
config.update({key: parsing_ops.VarLenFeature(dtypes.string)})
return config
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+
def transform_feature(self, transformation_cache, state_manager):
"""Generates a hashed sparse cross from the input tensors."""
feature_tensors = []
for key in _collect_leaf_level_keys(self):
if isinstance(key, six.string_types):
feature_tensors.append(transformation_cache.get(key, state_manager))
- elif isinstance(key, CategoricalColumn):
+ elif isinstance(key, (fc_old._CategoricalColumn, CategoricalColumn)): # pylint: disable=protected-access
ids_and_weights = key.get_sparse_tensors(transformation_cache,
state_manager)
if ids_and_weights.weight_tensor is not None:
@@ -3197,16 +3613,54 @@ class CrossedColumn(
num_buckets=self.hash_bucket_size,
hash_key=self.hash_key)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ """Generates a hashed sparse cross from the input tensors."""
+ feature_tensors = []
+ for key in _collect_leaf_level_keys(self):
+ if isinstance(key, six.string_types):
+ feature_tensors.append(inputs.get(key))
+ elif isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn)): # pylint: disable=protected-access
+ ids_and_weights = key._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ if ids_and_weights.weight_tensor is not None:
+ raise ValueError(
+ 'crossed_column does not support weight_tensor, but the given '
+ 'column populates weight_tensor. '
+ 'Given column: {}'.format(key.name))
+ feature_tensors.append(ids_and_weights.id_tensor)
+ else:
+ raise ValueError('Unsupported column type. Given: {}'.format(key))
+ return sparse_ops.sparse_cross_hashed(
+ inputs=feature_tensors,
+ num_buckets=self.hash_bucket_size,
+ hash_key=self.hash_key)
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return self.hash_bucket_size
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
return CategoricalColumn.IdWeightPair(
transformation_cache.get(self, state_manager), None)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ """See `CategoricalColumn` base class."""
+ del weight_collections
+ del trainable
+ return CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
def _collect_leaf_level_keys(cross):
"""Collects base keys by expanding all nested crosses.
@@ -3382,9 +3836,12 @@ def _prune_invalid_weights(sparse_ids, sparse_weights):
return sparse_ids, sparse_weights
-class IndicatorColumn(DenseColumn, SequenceDenseColumn,
- collections.namedtuple('IndicatorColumn',
- ('categorical_column'))):
+class IndicatorColumn(
+ DenseColumn,
+ SequenceDenseColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._SequenceDenseColumn, # pylint: disable=protected-access
+ collections.namedtuple('IndicatorColumn', ('categorical_column'))):
"""Represents a one-hot column for use in deep networks.
Args:
@@ -3393,27 +3850,16 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn,
"""
@property
+ def _is_v2_column(self):
+ return (isinstance(self.categorical_column, FeatureColumn) and
+ self.categorical_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return '{}_indicator'.format(self.categorical_column.name)
- def transform_feature(self, transformation_cache, state_manager):
- """Returns dense `Tensor` representing feature.
-
- Args:
- transformation_cache: A `FeatureTransformationCache` object to access
- features.
- state_manager: A `StateManager` to create / access resources such as
- lookup tables.
-
- Returns:
- Transformed feature `Tensor`.
-
- Raises:
- ValueError: if input rank is not known at graph building time.
- """
- id_weight_pair = self.categorical_column.get_sparse_tensors(
- transformation_cache, state_manager)
+ def _transform_id_weight_pair(self, id_weight_pair):
id_tensor = id_weight_pair.id_tensor
weight_tensor = id_weight_pair.weight_tensor
@@ -3422,7 +3868,7 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn,
weighted_column = sparse_ops.sparse_merge(
sp_ids=id_tensor,
sp_values=weight_tensor,
- vocab_size=int(self.variable_shape[-1]))
+ vocab_size=int(self._variable_shape[-1]))
# Remove (?, -1) index
weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0],
weighted_column.dense_shape)
@@ -3435,22 +3881,62 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn,
# input_layer are float32.
one_hot_id_tensor = array_ops.one_hot(
dense_id_tensor,
- depth=self.variable_shape[-1],
+ depth=self._variable_shape[-1],
on_value=1.0,
off_value=0.0)
# Reduce to get a multi-hot per example.
return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2])
+ def transform_feature(self, transformation_cache, state_manager):
+ """Returns dense `Tensor` representing feature.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Transformed feature `Tensor`.
+
+ Raises:
+ ValueError: if input rank is not known at graph building time.
+ """
+ id_weight_pair = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ return self._transform_id_weight_pair(id_weight_pair)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ id_weight_pair = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ return self._transform_id_weight_pair(id_weight_pair)
+
@property
def parse_example_spec(self):
"""See `FeatureColumn` base class."""
return self.categorical_column.parse_example_spec
@property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.categorical_column._parse_example_spec # pylint: disable=protected-access
+
+ @property
def variable_shape(self):
"""Returns a `TensorShape` representing the shape of the dense `Tensor`."""
- return tensor_shape.TensorShape([1, self.categorical_column.num_buckets])
+ if isinstance(self.categorical_column, FeatureColumn):
+ return tensor_shape.TensorShape([1, self.categorical_column.num_buckets])
+ else:
+ return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access
+
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _variable_shape(self):
+ return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access
def get_dense_tensor(self, transformation_cache, state_manager):
"""Returns dense `Tensor` representing feature.
@@ -3481,6 +3967,27 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn,
# representation created by transform_feature.
return transformation_cache.get(self, state_manager)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ del weight_collections
+ del trainable
+ if isinstance(
+ self.categorical_column,
+ (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
+ raise ValueError(
+ 'In indicator_column: {}. '
+ 'categorical_column must not be of type _SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ # Feature has been already transformed. Return the intermediate
+ # representation created by transform_feature.
+ return inputs.get(self)
+
def get_sequence_dense_tensor(self, transformation_cache, state_manager):
"""See `SequenceDenseColumn` base class."""
if not isinstance(self.categorical_column, SequenceCategoricalColumn):
@@ -3496,7 +4003,36 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn,
dense_tensor = transformation_cache.get(self, state_manager)
sparse_tensors = self.categorical_column.get_sparse_tensors(
transformation_cache, state_manager)
- sequence_length = _sequence_length_from_sparse_tensor(
+ sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
+ sparse_tensors.id_tensor)
+ return SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sequence_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ # Do nothing with weight_collections and trainable since no variables are
+ # created in this function.
+ del weight_collections
+ del trainable
+ if not isinstance(
+ self.categorical_column,
+ (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
+ raise ValueError(
+ 'In indicator_column: {}. '
+ 'categorical_column must be of type _SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ # Feature has been already transformed. Return the intermediate
+ # representation created by _transform_feature.
+ dense_tensor = inputs.get(self)
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
sparse_tensors.id_tensor)
return SequenceDenseColumn.TensorSequenceLengthPair(
dense_tensor=dense_tensor, sequence_length=sequence_length)
@@ -3518,28 +4054,19 @@ def _verify_static_batch_size_equality(tensors, columns):
expected_batch_size, tensors[i].shape[0]))
-def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
- """Returns a [batch_size] Tensor with per-example sequence length."""
- with ops.name_scope(None, 'sequence_length') as name_scope:
- row_ids = sp_tensor.indices[:, 0]
- column_ids = sp_tensor.indices[:, 1]
- column_ids += array_ops.ones_like(column_ids)
- seq_length = math_ops.to_int64(
- math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements)
- # If the last n rows do not have ids, seq_length will have shape
- # [batch_size - n]. Pad the remaining values with zeros.
- n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1]
- padding = array_ops.zeros(n_pad, dtype=seq_length.dtype)
- return array_ops.concat([seq_length, padding], axis=0, name=name_scope)
-
-
-class SequenceCategoricalColumn(FeatureColumn,
- collections.namedtuple(
- 'SequenceCategoricalColumn',
- ('categorical_column'))):
+class SequenceCategoricalColumn(
+ FeatureColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
+ collections.namedtuple('SequenceCategoricalColumn',
+ ('categorical_column'))):
"""Represents sequences of categorical data."""
@property
+ def _is_v2_column(self):
+ return (isinstance(self.categorical_column, FeatureColumn) and
+ self.categorical_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.categorical_column.name
@@ -3549,16 +4076,46 @@ class SequenceCategoricalColumn(FeatureColumn,
"""See `FeatureColumn` base class."""
return self.categorical_column.parse_example_spec
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.categorical_column._parse_example_spec # pylint: disable=protected-access
+
def transform_feature(self, transformation_cache, state_manager):
"""See `FeatureColumn` base class."""
return self.categorical_column.transform_feature(transformation_cache,
state_manager)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ return self.categorical_column._transform_feature(inputs) # pylint: disable=protected-access
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return self.categorical_column.num_buckets
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.categorical_column._num_buckets # pylint: disable=protected-access
+
+ def _get_sparse_tensors_helper(self, sparse_tensors):
+ id_tensor = sparse_tensors.id_tensor
+ weight_tensor = sparse_tensors.weight_tensor
+ # Expands third dimension, if necessary so that embeddings are not
+ # combined during embedding lookup. If the tensor is already 3D, leave
+ # as-is.
+ shape = array_ops.shape(id_tensor)
+ target_shape = [shape[0], shape[1], -1]
+ id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
+ if weight_tensor is not None:
+ weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
+ return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
+
def get_sequence_sparse_tensors(self, transformation_cache, state_manager):
"""Returns an IdWeightPair.
@@ -3580,27 +4137,11 @@ class SequenceCategoricalColumn(FeatureColumn,
"""
sparse_tensors = self.categorical_column.get_sparse_tensors(
transformation_cache, state_manager)
- id_tensor = sparse_tensors.id_tensor
- weight_tensor = sparse_tensors.weight_tensor
- # Expands final dimension, so that embeddings are not combined during
- # embedding lookup.
- check_id_rank = check_ops.assert_equal(
- array_ops.rank(id_tensor), 2,
- data=[
- 'Column {} expected ID tensor of rank 2. '.format(self.name),
- 'id_tensor shape: ', array_ops.shape(id_tensor)])
- with ops.control_dependencies([check_id_rank]):
- id_tensor = sparse_ops.sparse_reshape(
- id_tensor,
- shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0))
- if weight_tensor is not None:
- check_weight_rank = check_ops.assert_equal(
- array_ops.rank(weight_tensor), 2,
- data=[
- 'Column {} expected weight tensor of rank 2.'.format(self.name),
- 'weight_tensor shape:', array_ops.shape(weight_tensor)])
- with ops.control_dependencies([check_weight_rank]):
- weight_tensor = sparse_ops.sparse_reshape(
- weight_tensor,
- shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
- return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
+ return self._get_sparse_tensors_helper(sparse_tensors)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ return self._get_sparse_tensors_helper(sparse_tensors)