aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column/feature_column_v2.py
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-10-09 08:16:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 08:21:39 -0700
commitcadcacc6224bcbb8a05bf3b70d625d9024a9c0f3 (patch)
treefe73a2d1ed500dbd1e5b0f6f20229e534f813d90 /tensorflow/python/feature_column/feature_column_v2.py
parenta0ed9452d5c7f897e26788d8dca5164cb6fba023 (diff)
Allowing for mixture of V1 and V2 feature columns usage in canned estimators. This is required for TF hub use cases where users might send in new feature columns to old model code. Implemented this support by making V2 feature columns support the V1 API. This is needed temporarily and would definitely be removed by TF 2.0, possibly earlier depending on what guarantees are provided by TF hub.
The only case we don't allow here is mixing in V2 shared embedding columns with V1 Feature columns. V2 Shared FC's depend on a SharedEmbeddingState manager that would have to be passed in to the various API's, and there wasn't really a very clean way to make that work. Mixing V2 feature columns with V1 shared embedding columns is fine, though, as are all other combinations. PiperOrigin-RevId: 216359041
Diffstat (limited to 'tensorflow/python/feature_column/feature_column_v2.py')
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py869
1 files changed, 705 insertions, 164 deletions
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index b79373c475..6d089de991 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -136,6 +136,7 @@ import six
from tensorflow.python.eager import context
+from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -157,9 +158,16 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
+_FEATURE_COLUMN_DEPRECATION_DATE = '2018-11-30'
+_FEATURE_COLUMN_DEPRECATION = ('The old _FeatureColumn APIs are being '
+ 'deprecated. Please use the new FeatureColumn '
+ 'APIs instead.')
+
+
class StateManager(object):
"""Manages the state associated with FeatureColumns.
@@ -440,10 +448,6 @@ class FeatureLayer(Layer):
return (input_shape[0], total_elements)
-def _strip_leading_slashes(name):
- return name.rsplit('/', 1)[-1]
-
-
class LinearModel(Layer):
"""Produces a linear prediction `Tensor` based on given `feature_columns`.
@@ -775,12 +779,12 @@ def embedding_column(
categorical_column, dimension, combiner='mean', initializer=None,
ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None,
trainable=True):
- """`_DenseColumn` that converts from sparse, categorical input.
+ """`DenseColumn` that converts from sparse, categorical input.
Use this when your inputs are sparse, but you want to convert them to a dense
representation (e.g., to feed to a DNN).
- Inputs must be a `_CategoricalColumn` created by any of the
+ Inputs must be a `CategoricalColumn` created by any of the
`categorical_column_*` function. Here is an example of using
`embedding_column` with `DNNClassifier`:
@@ -814,12 +818,12 @@ def embedding_column(
```
Args:
- categorical_column: A `_CategoricalColumn` created by a
+ categorical_column: A `CategoricalColumn` created by a
`categorical_column_with_*` function. This column produces the sparse IDs
that are inputs to the embedding lookup.
dimension: An integer specifying dimension of the embedding, must be > 0.
- combiner: A string specifying how to reduce if there are multiple entries
- in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
+ combiner: A string specifying how to reduce if there are multiple entries in
+ a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
'mean' the default. 'sqrtn' often achieves good accuracy, in particular
with bag-of-words columns. Each of this can be thought as example level
normalizations on the column. For more information, see
@@ -830,14 +834,14 @@ def embedding_column(
`1/sqrt(dimension)`.
ckpt_to_load_from: String representing checkpoint name/pattern from which to
restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
- tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
- which to restore the column weights. Required if `ckpt_to_load_from` is
- not `None`.
+ tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
+ to restore the column weights. Required if `ckpt_to_load_from` is not
+ `None`.
max_norm: If not `None`, embedding values are l2-normalized to this value.
trainable: Whether or not the embedding is trainable. Default is True.
Returns:
- `_DenseColumn` that converts from sparse input.
+ `DenseColumn` that converts from sparse input.
Raises:
ValueError: if `dimension` not > 0.
@@ -1181,7 +1185,7 @@ def bucketized_column(source_column, boundaries):
one-dimensional.
ValueError: If `boundaries` is not a sorted list or tuple.
"""
- if not isinstance(source_column, NumericColumn):
+ if not isinstance(source_column, (NumericColumn, fc_old._NumericColumn)): # pylint: disable=protected-access
raise ValueError(
'source_column must be a column generated with numeric_column(). '
'Given: {}'.format(source_column))
@@ -1390,7 +1394,7 @@ def categorical_column_with_vocabulary_file(key,
def categorical_column_with_vocabulary_list(
key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0):
- """A `_CategoricalColumn` with in-memory vocabulary.
+ """A `CategoricalColumn` with in-memory vocabulary.
Use this when your inputs are in string or integer format, and you have an
in-memory vocabulary mapping each value to an integer ID. By default,
@@ -1439,14 +1443,14 @@ def categorical_column_with_vocabulary_list(
```
Args:
- key: A unique string identifying the input feature. It is used as the
- column name and the dictionary key for feature parsing configs, feature
- `Tensor` objects, and feature columns.
+ key: A unique string identifying the input feature. It is used as the column
+ name and the dictionary key for feature parsing configs, feature `Tensor`
+ objects, and feature columns.
vocabulary_list: An ordered iterable defining the vocabulary. Each feature
is mapped to the index of its value (if present) in `vocabulary_list`.
Must be castable to `dtype`.
- dtype: The type of features. Only string and integer types are supported.
- If `None`, it will be inferred from `vocabulary_list`.
+ dtype: The type of features. Only string and integer types are supported. If
+ `None`, it will be inferred from `vocabulary_list`.
default_value: The integer ID value to return for out-of-vocabulary feature
values, defaults to `-1`. This can not be specified with a positive
`num_oov_buckets`.
@@ -1604,7 +1608,7 @@ def indicator_column(categorical_column):
def weighted_categorical_column(
categorical_column, weight_feature_key, dtype=dtypes.float32):
- """Applies weight values to a `_CategoricalColumn`.
+ """Applies weight values to a `CategoricalColumn`.
Use this when each of your sparse inputs has both an ID and a value. For
example, if you're representing text documents as a collection of word
@@ -1655,7 +1659,7 @@ def weighted_categorical_column(
the same indices and dense shape.
Args:
- categorical_column: A `_CategoricalColumn` created by
+ categorical_column: A `CategoricalColumn` created by
`categorical_column_with_*` functions.
weight_feature_key: String key for weight values.
dtype: Type of weights, such as `tf.float32`. Only float and integer weights
@@ -1788,12 +1792,13 @@ def crossed_column(keys, hash_bucket_size, hash_key=None):
'keys must be a list with length > 1. Given: {}'.format(keys))
for key in keys:
if (not isinstance(key, six.string_types) and
- not isinstance(key, CategoricalColumn)):
+ not isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn))): # pylint: disable=protected-access
raise ValueError(
'Unsupported key type. All keys must be either string, or '
'categorical column except HashedCategoricalColumn. '
'Given: {}'.format(key))
- if isinstance(key, HashedCategoricalColumn):
+ if isinstance(key,
+ (HashedCategoricalColumn, fc_old._HashedCategoricalColumn)): # pylint: disable=protected-access
raise ValueError(
'categorical_column_with_hash_bucket is not supported for crossing. '
'Hashing before crossing will increase probability of collision. '
@@ -1882,6 +1887,16 @@ class FeatureColumn(object):
"""
pass
+ @abc.abstractproperty
+ def _is_v2_column(self):
+ """Returns whether this FeatureColumn is fully conformant to the new API.
+
+ This is needed for composition type cases where an EmbeddingColumn etc.
+ might take in old categorical columns as input and then we want to use the
+ old API.
+ """
+ pass
+
class DenseColumn(FeatureColumn):
"""Represents a column which can be represented as `Tensor`.
@@ -1927,6 +1942,8 @@ def is_feature_column_v2(feature_columns):
for feature_column in feature_columns:
if not isinstance(feature_column, FeatureColumn):
return False
+ if not feature_column._is_v2_column: # pylint: disable=protected-access
+ return False
return True
@@ -2202,19 +2219,6 @@ class FeatureTransformationCache(object):
# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
-def _shape_offsets(shape):
- """Returns moving offset for each dimension given shape."""
- offsets = []
- for dim in reversed(shape):
- if offsets:
- offsets.append(dim * offsets[-1])
- else:
- offsets.append(dim)
- offsets.reverse()
- return offsets
-
-
-# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
"""Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
@@ -2306,12 +2310,17 @@ def _normalize_feature_columns(feature_columns):
class NumericColumn(
DenseColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
collections.namedtuple(
'NumericColumn',
('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
"""see `numeric_column`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -2325,6 +2334,27 @@ class NumericColumn(
self.default_value)
}
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+
+ def _transform_input_tensor(self, input_tensor):
+ if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+ raise ValueError(
+ 'The corresponding Tensor of numerical column must be a Tensor. '
+ 'SparseTensor is not supported. key: {}'.format(self.key))
+ if self.normalizer_fn is not None:
+ input_tensor = self.normalizer_fn(input_tensor)
+ return math_ops.to_float(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ input_tensor = inputs.get(self.key)
+ return self._transform_input_tensor(input_tensor)
+
def transform_feature(self, transformation_cache, state_manager):
"""See `FeatureColumn` base class.
@@ -2342,19 +2372,19 @@ class NumericColumn(
ValueError: If a SparseTensor is passed in.
"""
input_tensor = transformation_cache.get(self.key, state_manager)
- if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
- raise ValueError(
- 'The corresponding Tensor of numerical column must be a Tensor. '
- 'SparseTensor is not supported. key: {}'.format(self.key))
- if self.normalizer_fn is not None:
- input_tensor = self.normalizer_fn(input_tensor)
- return math_ops.to_float(input_tensor)
+ return self._transform_input_tensor(input_tensor)
@property
def variable_shape(self):
"""See `DenseColumn` base class."""
return tensor_shape.TensorShape(self.shape)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _variable_shape(self):
+ return self.variable_shape
+
def get_dense_tensor(self, transformation_cache, state_manager):
"""Returns dense `Tensor` representing numeric feature.
@@ -2371,13 +2401,29 @@ class NumericColumn(
# representation created by _transform_feature.
return transformation_cache.get(self, state_manager)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ del weight_collections
+ del trainable
+ return inputs.get(self)
+
-class BucketizedColumn(DenseColumn, CategoricalColumn,
- collections.namedtuple('BucketizedColumn',
- ('source_column', 'boundaries'))):
+class BucketizedColumn(
+ DenseColumn,
+ CategoricalColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
+ collections.namedtuple('BucketizedColumn',
+ ('source_column', 'boundaries'))):
"""See `bucketized_column`."""
@property
+ def _is_v2_column(self):
+ return (isinstance(self.source_column, FeatureColumn) and
+ self.source_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return '{}_bucketized'.format(self.source_column.name)
@@ -2387,6 +2433,21 @@ class BucketizedColumn(DenseColumn, CategoricalColumn,
"""See `FeatureColumn` base class."""
return self.source_column.parse_example_spec
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.source_column._parse_example_spec # pylint: disable=protected-access
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ """Returns bucketized categorical `source_column` tensor."""
+ source_tensor = inputs.get(self.source_column)
+ return math_ops._bucketize( # pylint: disable=protected-access
+ source_tensor,
+ boundaries=self.boundaries)
+
def transform_feature(self, transformation_cache, state_manager):
"""Returns bucketized categorical `source_column` tensor."""
source_tensor = transformation_cache.get(self.source_column, state_manager)
@@ -2400,24 +2461,45 @@ class BucketizedColumn(DenseColumn, CategoricalColumn,
return tensor_shape.TensorShape(
tuple(self.source_column.shape) + (len(self.boundaries) + 1,))
- def get_dense_tensor(self, transformation_cache, state_manager):
- """Returns one hot encoded dense `Tensor`."""
- input_tensor = transformation_cache.get(self, state_manager)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _variable_shape(self):
+ return self.variable_shape
+
+ def _get_dense_tensor_for_input_tensor(self, input_tensor):
return array_ops.one_hot(
indices=math_ops.to_int64(input_tensor),
depth=len(self.boundaries) + 1,
on_value=1.,
off_value=0.)
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns one hot encoded dense `Tensor`."""
+ input_tensor = transformation_cache.get(self, state_manager)
+ return self._get_dense_tensor_for_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ del weight_collections
+ del trainable
+ input_tensor = inputs.get(self)
+ return self._get_dense_tensor_for_input_tensor(input_tensor)
+
@property
def num_buckets(self):
"""See `CategoricalColumn` base class."""
# By construction, source_column is always one-dimensional.
return (len(self.boundaries) + 1) * self.source_column.shape[0]
- def get_sparse_tensors(self, transformation_cache, state_manager):
- """Converts dense inputs to SparseTensor so downstream code can use it."""
- input_tensor = transformation_cache.get(self, state_manager)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
+ def _get_sparse_tensors_for_input_tensor(self, input_tensor):
batch_size = array_ops.shape(input_tensor)[0]
# By construction, source_column is always one-dimensional.
source_dimension = self.source_column.shape[0]
@@ -2443,9 +2525,27 @@ class BucketizedColumn(DenseColumn, CategoricalColumn,
dense_shape=dense_shape)
return CategoricalColumn.IdWeightPair(sparse_tensor, None)
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """Converts dense inputs to SparseTensor so downstream code can use it."""
+ input_tensor = transformation_cache.get(self, state_manager)
+ return self._get_sparse_tensors_for_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ """Converts dense inputs to SparseTensor so downstream code can use it."""
+ del weight_collections
+ del trainable
+ input_tensor = inputs.get(self)
+ return self._get_sparse_tensors_for_input_tensor(input_tensor)
+
class EmbeddingColumn(
- DenseColumn, SequenceDenseColumn,
+ DenseColumn,
+ SequenceDenseColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._SequenceDenseColumn, # pylint: disable=protected-access
collections.namedtuple(
'EmbeddingColumn',
('categorical_column', 'dimension', 'combiner', 'initializer',
@@ -2453,6 +2553,11 @@ class EmbeddingColumn(
"""See `embedding_column`."""
@property
+ def _is_v2_column(self):
+ return (isinstance(self.categorical_column, FeatureColumn) and
+ self.categorical_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return '{}_embedding'.format(self.categorical_column.name)
@@ -2462,18 +2567,35 @@ class EmbeddingColumn(
"""See `FeatureColumn` base class."""
return self.categorical_column.parse_example_spec
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.categorical_column._parse_example_spec # pylint: disable=protected-access
+
def transform_feature(self, transformation_cache, state_manager):
"""Transforms underlying `categorical_column`."""
return transformation_cache.get(self.categorical_column, state_manager)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ return inputs.get(self.categorical_column)
+
@property
def variable_shape(self):
"""See `DenseColumn` base class."""
return tensor_shape.vector(self.dimension)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _variable_shape(self):
+ return self.variable_shape
+
def create_state(self, state_manager):
"""Creates the embedding lookup variable."""
- embedding_shape = (self.categorical_column.num_buckets, self.dimension)
+ embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
state_manager.create_variable(
self,
name='embedding_weights',
@@ -2482,17 +2604,11 @@ class EmbeddingColumn(
trainable=self.trainable,
initializer=self.initializer)
- def _get_dense_tensor_internal(self, transformation_cache, state_manager):
- """Private method that follows the signature of _get_dense_tensor."""
- # Get sparse IDs and weights.
- sparse_tensors = self.categorical_column.get_sparse_tensors(
- transformation_cache, state_manager)
+ def _get_dense_tensor_internal_helper(self, sparse_tensors,
+ embedding_weights):
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- embedding_weights = state_manager.get_variable(
- self, name='embedding_weights')
-
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
if isinstance(to_restore, variables.PartitionedVariable):
@@ -2510,6 +2626,30 @@ class EmbeddingColumn(
name='%s_weights' % self.name,
max_norm=self.max_norm)
+ def _get_dense_tensor_internal(self, sparse_tensors, state_manager):
+ """Private method that follows the signature of get_dense_tensor."""
+ embedding_weights = state_manager.get_variable(
+ self, name='embedding_weights')
+ return self._get_dense_tensor_internal_helper(sparse_tensors,
+ embedding_weights)
+
+ def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
+ trainable):
+ """Private method that follows the signature of _get_dense_tensor."""
+ embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
+ if (weight_collections and
+ ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
+ weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
+ embedding_weights = variable_scope.get_variable(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer,
+ trainable=self.trainable and trainable,
+ collections=weight_collections)
+ return self._get_dense_tensor_internal_helper(sparse_tensors,
+ embedding_weights)
+
def get_dense_tensor(self, transformation_cache, state_manager):
"""Returns tensor after doing the embedding lookup.
@@ -2535,7 +2675,30 @@ class EmbeddingColumn(
'sequence_input_layer instead of input_layer. '
'Given (type {}): {}'.format(self.name, type(self.categorical_column),
self.categorical_column))
- return self._get_dense_tensor_internal(transformation_cache, state_manager)
+ # Get sparse IDs and weights.
+ sparse_tensors = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ return self._get_dense_tensor_internal(sparse_tensors, state_manager)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ if isinstance(
+ self.categorical_column,
+ (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must not be of type _SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ sparse_tensors = self.categorical_column._get_sparse_tensors( # pylint: disable=protected-access
+ inputs, weight_collections, trainable)
+ return self._old_get_dense_tensor_internal(sparse_tensors,
+ weight_collections, trainable)
def get_sequence_dense_tensor(self, transformation_cache, state_manager):
"""See `SequenceDenseColumn` base class."""
@@ -2547,21 +2710,40 @@ class EmbeddingColumn(
'Suggested fix: Use one of sequence_categorical_column_with_*. '
'Given (type {}): {}'.format(self.name, type(self.categorical_column),
self.categorical_column))
- dense_tensor = self._get_dense_tensor_internal( # pylint: disable=protected-access
+ sparse_tensors = self.categorical_column.get_sequence_sparse_tensors(
transformation_cache, state_manager)
- sparse_tensors = self.categorical_column.get_sparse_tensors(
- transformation_cache, state_manager)
- sequence_length = _sequence_length_from_sparse_tensor(
+ dense_tensor = self._get_dense_tensor_internal(sparse_tensors,
+ state_manager)
+ sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
sparse_tensors.id_tensor)
return SequenceDenseColumn.TensorSequenceLengthPair(
dense_tensor=dense_tensor, sequence_length=sequence_length)
-
-def _get_graph_for_variable(var):
- if isinstance(var, variables.PartitionedVariable):
- return list(var)[0].graph
- else:
- return var.graph
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sequence_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ if not isinstance(
+ self.categorical_column,
+ (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must be of type _SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ dense_tensor = self._old_get_dense_tensor_internal(
+ sparse_tensors,
+ weight_collections=weight_collections,
+ trainable=trainable)
+ sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
+ sparse_tensors.id_tensor)
+ return SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
class SharedEmbeddingStateManager(Layer):
@@ -2633,8 +2815,17 @@ def maybe_create_shared_state_manager(feature_columns):
return None
+def _raise_shared_embedding_column_error():
+ raise ValueError('SharedEmbeddingColumns are not supported in '
+ '`linear_model` or `input_layer`. Please use '
+ '`FeatureLayer` or `LinearModel` instead.')
+
+
class SharedEmbeddingColumn(
- DenseColumn, SequenceDenseColumn,
+ DenseColumn,
+ SequenceDenseColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._SequenceDenseColumn, # pylint: disable=protected-access
collections.namedtuple(
'SharedEmbeddingColumn',
('categorical_column', 'dimension', 'combiner', 'initializer',
@@ -2643,6 +2834,10 @@ class SharedEmbeddingColumn(
"""See `embedding_column`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return '{}_shared_embedding'.format(self.categorical_column.name)
@@ -2662,15 +2857,26 @@ class SharedEmbeddingColumn(
"""See `FeatureColumn` base class."""
return self.categorical_column.parse_example_spec
+ @property
+ def _parse_example_spec(self):
+ return _raise_shared_embedding_column_error()
+
def transform_feature(self, transformation_cache, state_manager):
"""See `FeatureColumn` base class."""
return transformation_cache.get(self.categorical_column, state_manager)
+ def _transform_feature(self, inputs):
+ return _raise_shared_embedding_column_error()
+
@property
def variable_shape(self):
"""See `DenseColumn` base class."""
return tensor_shape.vector(self.dimension)
+ @property
+ def _variable_shape(self):
+ return _raise_shared_embedding_column_error()
+
def create_state(self, state_manager):
"""Creates the shared embedding lookup variable."""
if not isinstance(state_manager, SharedEmbeddingStateManager):
@@ -2731,6 +2937,9 @@ class SharedEmbeddingColumn(
self.categorical_column))
return self._get_dense_tensor_internal(transformation_cache, state_manager)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ return _raise_shared_embedding_column_error()
+
def get_sequence_dense_tensor(self, transformation_cache, state_manager):
"""See `SequenceDenseColumn` base class."""
if not isinstance(self.categorical_column, SequenceCategoricalColumn):
@@ -2745,11 +2954,17 @@ class SharedEmbeddingColumn(
state_manager)
sparse_tensors = self.categorical_column.get_sparse_tensors(
transformation_cache, state_manager)
- sequence_length = _sequence_length_from_sparse_tensor(
+ sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
sparse_tensors.id_tensor)
return SequenceDenseColumn.TensorSequenceLengthPair(
dense_tensor=dense_tensor, sequence_length=sequence_length)
+ def _get_sequence_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ return _raise_shared_embedding_column_error()
+
def _create_tuple(shape, value):
"""Returns a tuple with given shape and filled with value."""
@@ -2858,11 +3073,16 @@ def _check_default_value(shape, default_value, dtype, key):
class HashedCategoricalColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple('HashedCategoricalColumn',
('key', 'hash_bucket_size', 'dtype'))):
"""see `categorical_column_with_hash_bucket`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -2872,10 +3092,14 @@ class HashedCategoricalColumn(
"""See `FeatureColumn` base class."""
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
- def transform_feature(self, transformation_cache, state_manager):
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+
+ def _transform_input_tensor(self, input_tensor):
"""Hashes the values in the feature_column."""
- input_tensor = _to_sparse_input_and_drop_ignore_values(
- transformation_cache.get(self.key, state_manager))
if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
raise ValueError('SparseColumn input must be a SparseTensor.')
@@ -2899,25 +3123,56 @@ class HashedCategoricalColumn(
return sparse_tensor_lib.SparseTensor(
input_tensor.indices, sparse_id_values, input_tensor.dense_shape)
+ def transform_feature(self, transformation_cache, state_manager):
+ """Hashes the values in the feature_column."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+ return self._transform_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
+ return self._transform_input_tensor(input_tensor)
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return self.hash_bucket_size
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
return CategoricalColumn.IdWeightPair(
transformation_cache.get(self, state_manager), None)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ return CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
class VocabularyFileCategoricalColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple('VocabularyFileCategoricalColumn',
('key', 'vocabulary_file', 'vocabulary_size',
'num_oov_buckets', 'dtype', 'default_value'))):
"""See `categorical_column_with_vocabulary_file`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -2927,11 +3182,14 @@ class VocabularyFileCategoricalColumn(
"""See `FeatureColumn` base class."""
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
- def transform_feature(self, transformation_cache, state_manager):
- """Creates a lookup table for the vocabulary."""
- input_tensor = _to_sparse_input_and_drop_ignore_values(
- transformation_cache.get(self.key, state_manager))
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+ def _transform_input_tensor(self, input_tensor):
+ """Creates a lookup table for the vocabulary."""
if self.dtype.is_integer != input_tensor.dtype.is_integer:
raise ValueError(
'Column dtype and SparseTensors dtype must be compatible. '
@@ -2957,19 +3215,46 @@ class VocabularyFileCategoricalColumn(
key_dtype=key_dtype,
name='{}_lookup'.format(self.key)).lookup(input_tensor)
+ def transform_feature(self, transformation_cache, state_manager):
+ """Creates a lookup table for the vocabulary."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+ return self._transform_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
+ return self._transform_input_tensor(input_tensor)
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return self.vocabulary_size + self.num_oov_buckets
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
return CategoricalColumn.IdWeightPair(
transformation_cache.get(self, state_manager), None)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ return CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
class VocabularyListCategoricalColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple(
'VocabularyListCategoricalColumn',
('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets'))
@@ -2977,6 +3262,10 @@ class VocabularyListCategoricalColumn(
"""See `categorical_column_with_vocabulary_list`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -2986,11 +3275,14 @@ class VocabularyListCategoricalColumn(
"""See `FeatureColumn` base class."""
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
- def transform_feature(self, transformation_cache, state_manager):
- """Creates a lookup table for the vocabulary list."""
- input_tensor = _to_sparse_input_and_drop_ignore_values(
- transformation_cache.get(self.key, state_manager))
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+ def _transform_input_tensor(self, input_tensor):
+ """Creates a lookup table for the vocabulary list."""
if self.dtype.is_integer != input_tensor.dtype.is_integer:
raise ValueError(
'Column dtype and SparseTensors dtype must be compatible. '
@@ -3015,25 +3307,56 @@ class VocabularyListCategoricalColumn(
dtype=key_dtype,
name='{}_lookup'.format(self.key)).lookup(input_tensor)
+ def transform_feature(self, transformation_cache, state_manager):
+ """Creates a lookup table for the vocabulary list."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+ return self._transform_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
+ return self._transform_input_tensor(input_tensor)
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return len(self.vocabulary_list) + self.num_oov_buckets
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
return CategoricalColumn.IdWeightPair(
transformation_cache.get(self, state_manager), None)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ return CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
class IdentityCategoricalColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple('IdentityCategoricalColumn',
('key', 'number_buckets', 'default_value'))):
"""See `categorical_column_with_identity`."""
@property
+ def _is_v2_column(self):
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.key
@@ -3043,11 +3366,14 @@ class IdentityCategoricalColumn(
"""See `FeatureColumn` base class."""
return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
- def transform_feature(self, transformation_cache, state_manager):
- """Returns a SparseTensor with identity values."""
- input_tensor = _to_sparse_input_and_drop_ignore_values(
- transformation_cache.get(self.key, state_manager))
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+ def _transform_input_tensor(self, input_tensor):
+ """Returns a SparseTensor with identity values."""
if not input_tensor.dtype.is_integer:
raise ValueError(
'Invalid input, not integer. key: {} dtype: {}'.format(
@@ -3082,25 +3408,57 @@ class IdentityCategoricalColumn(
values=values,
dense_shape=input_tensor.dense_shape)
+ def transform_feature(self, transformation_cache, state_manager):
+ """Returns a SparseTensor with identity values."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+ return self._transform_input_tensor(input_tensor)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
+ return self._transform_input_tensor(input_tensor)
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return self.number_buckets
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
return CategoricalColumn.IdWeightPair(
transformation_cache.get(self, state_manager), None)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ return CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
class WeightedCategoricalColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple(
'WeightedCategoricalColumn',
('categorical_column', 'weight_feature_key', 'dtype'))):
"""See `weighted_categorical_column`."""
@property
+ def _is_v2_column(self):
+ return (isinstance(self.categorical_column, FeatureColumn) and
+ self.categorical_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return '{}_weighted_by_{}'.format(
@@ -3117,14 +3475,28 @@ class WeightedCategoricalColumn(
return config
@property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ config = self.categorical_column._parse_example_spec # pylint: disable=protected-access
+ if self.weight_feature_key in config:
+ raise ValueError('Parse config {} already exists for {}.'.format(
+ config[self.weight_feature_key], self.weight_feature_key))
+ config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
+ return config
+
+ @property
def num_buckets(self):
"""See `DenseColumn` base class."""
return self.categorical_column.num_buckets
- def transform_feature(self, transformation_cache, state_manager):
- """Applies weights to tensor generated from `categorical_column`'."""
- weight_tensor = transformation_cache.get(self.weight_feature_key,
- state_manager)
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.categorical_column._num_buckets # pylint: disable=protected-access
+
+ def _transform_weight_tensor(self, weight_tensor):
if weight_tensor is None:
raise ValueError('Missing weights {}.'.format(self.weight_feature_key))
weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
@@ -3138,27 +3510,63 @@ class WeightedCategoricalColumn(
weight_tensor, ignore_value=0.0)
if not weight_tensor.dtype.is_floating:
weight_tensor = math_ops.to_float(weight_tensor)
+ return weight_tensor
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Applies weights to tensor generated from `categorical_column`'."""
+ weight_tensor = transformation_cache.get(self.weight_feature_key,
+ state_manager)
+ weight_tensor = self._transform_weight_tensor(weight_tensor)
return (transformation_cache.get(self.categorical_column, state_manager),
weight_tensor)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ """Applies weights to tensor generated from `categorical_column`'."""
+ weight_tensor = inputs.get(self.weight_feature_key)
+ weight_tensor = self._transform_weight_tensor(weight_tensor)
+ return (inputs.get(self.categorical_column), weight_tensor)
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
tensors = transformation_cache.get(self, state_manager)
return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ tensors = inputs.get(self)
+ return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
+
class CrossedColumn(
CategoricalColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple('CrossedColumn',
('keys', 'hash_bucket_size', 'hash_key'))):
"""See `crossed_column`."""
@property
+ def _is_v2_column(self):
+ for key in _collect_leaf_level_keys(self):
+ if isinstance(key, six.string_types):
+ continue
+ if not isinstance(key, FeatureColumn):
+ return False
+ if not key._is_v2_column: # pylint: disable=protected-access
+ return False
+ return True
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
feature_names = []
for key in _collect_leaf_level_keys(self):
- if isinstance(key, FeatureColumn):
+ if isinstance(key, (FeatureColumn, fc_old._FeatureColumn)): # pylint: disable=protected-access
feature_names.append(key.name)
else: # key must be a string
feature_names.append(key)
@@ -3171,17 +3579,25 @@ class CrossedColumn(
for key in self.keys:
if isinstance(key, FeatureColumn):
config.update(key.parse_example_spec)
+ elif isinstance(key, fc_old._FeatureColumn): # pylint: disable=protected-access
+ config.update(key._parse_example_spec) # pylint: disable=protected-access
else: # key must be a string
config.update({key: parsing_ops.VarLenFeature(dtypes.string)})
return config
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.parse_example_spec
+
def transform_feature(self, transformation_cache, state_manager):
"""Generates a hashed sparse cross from the input tensors."""
feature_tensors = []
for key in _collect_leaf_level_keys(self):
if isinstance(key, six.string_types):
feature_tensors.append(transformation_cache.get(key, state_manager))
- elif isinstance(key, CategoricalColumn):
+ elif isinstance(key, (fc_old._CategoricalColumn, CategoricalColumn)): # pylint: disable=protected-access
ids_and_weights = key.get_sparse_tensors(transformation_cache,
state_manager)
if ids_and_weights.weight_tensor is not None:
@@ -3197,16 +3613,54 @@ class CrossedColumn(
num_buckets=self.hash_bucket_size,
hash_key=self.hash_key)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ """Generates a hashed sparse cross from the input tensors."""
+ feature_tensors = []
+ for key in _collect_leaf_level_keys(self):
+ if isinstance(key, six.string_types):
+ feature_tensors.append(inputs.get(key))
+ elif isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn)): # pylint: disable=protected-access
+ ids_and_weights = key._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ if ids_and_weights.weight_tensor is not None:
+ raise ValueError(
+ 'crossed_column does not support weight_tensor, but the given '
+ 'column populates weight_tensor. '
+ 'Given column: {}'.format(key.name))
+ feature_tensors.append(ids_and_weights.id_tensor)
+ else:
+ raise ValueError('Unsupported column type. Given: {}'.format(key))
+ return sparse_ops.sparse_cross_hashed(
+ inputs=feature_tensors,
+ num_buckets=self.hash_bucket_size,
+ hash_key=self.hash_key)
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return self.hash_bucket_size
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.num_buckets
+
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
return CategoricalColumn.IdWeightPair(
transformation_cache.get(self, state_manager), None)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ """See `CategoricalColumn` base class."""
+ del weight_collections
+ del trainable
+ return CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
def _collect_leaf_level_keys(cross):
"""Collects base keys by expanding all nested crosses.
@@ -3382,9 +3836,12 @@ def _prune_invalid_weights(sparse_ids, sparse_weights):
return sparse_ids, sparse_weights
-class IndicatorColumn(DenseColumn, SequenceDenseColumn,
- collections.namedtuple('IndicatorColumn',
- ('categorical_column'))):
+class IndicatorColumn(
+ DenseColumn,
+ SequenceDenseColumn,
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._SequenceDenseColumn, # pylint: disable=protected-access
+ collections.namedtuple('IndicatorColumn', ('categorical_column'))):
"""Represents a one-hot column for use in deep networks.
Args:
@@ -3393,27 +3850,16 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn,
"""
@property
+ def _is_v2_column(self):
+ return (isinstance(self.categorical_column, FeatureColumn) and
+ self.categorical_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return '{}_indicator'.format(self.categorical_column.name)
- def transform_feature(self, transformation_cache, state_manager):
- """Returns dense `Tensor` representing feature.
-
- Args:
- transformation_cache: A `FeatureTransformationCache` object to access
- features.
- state_manager: A `StateManager` to create / access resources such as
- lookup tables.
-
- Returns:
- Transformed feature `Tensor`.
-
- Raises:
- ValueError: if input rank is not known at graph building time.
- """
- id_weight_pair = self.categorical_column.get_sparse_tensors(
- transformation_cache, state_manager)
+ def _transform_id_weight_pair(self, id_weight_pair):
id_tensor = id_weight_pair.id_tensor
weight_tensor = id_weight_pair.weight_tensor
@@ -3422,7 +3868,7 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn,
weighted_column = sparse_ops.sparse_merge(
sp_ids=id_tensor,
sp_values=weight_tensor,
- vocab_size=int(self.variable_shape[-1]))
+ vocab_size=int(self._variable_shape[-1]))
# Remove (?, -1) index
weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0],
weighted_column.dense_shape)
@@ -3435,22 +3881,62 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn,
# input_layer are float32.
one_hot_id_tensor = array_ops.one_hot(
dense_id_tensor,
- depth=self.variable_shape[-1],
+ depth=self._variable_shape[-1],
on_value=1.0,
off_value=0.0)
# Reduce to get a multi-hot per example.
return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2])
+ def transform_feature(self, transformation_cache, state_manager):
+ """Returns dense `Tensor` representing feature.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Transformed feature `Tensor`.
+
+ Raises:
+ ValueError: if input rank is not known at graph building time.
+ """
+ id_weight_pair = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ return self._transform_id_weight_pair(id_weight_pair)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ id_weight_pair = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ return self._transform_id_weight_pair(id_weight_pair)
+
@property
def parse_example_spec(self):
"""See `FeatureColumn` base class."""
return self.categorical_column.parse_example_spec
@property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.categorical_column._parse_example_spec # pylint: disable=protected-access
+
+ @property
def variable_shape(self):
"""Returns a `TensorShape` representing the shape of the dense `Tensor`."""
- return tensor_shape.TensorShape([1, self.categorical_column.num_buckets])
+ if isinstance(self.categorical_column, FeatureColumn):
+ return tensor_shape.TensorShape([1, self.categorical_column.num_buckets])
+ else:
+ return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access
+
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _variable_shape(self):
+ return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access
def get_dense_tensor(self, transformation_cache, state_manager):
"""Returns dense `Tensor` representing feature.
@@ -3481,6 +3967,27 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn,
# representation created by transform_feature.
return transformation_cache.get(self, state_manager)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ del weight_collections
+ del trainable
+ if isinstance(
+ self.categorical_column,
+ (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
+ raise ValueError(
+ 'In indicator_column: {}. '
+ 'categorical_column must not be of type _SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ # Feature has been already transformed. Return the intermediate
+ # representation created by transform_feature.
+ return inputs.get(self)
+
def get_sequence_dense_tensor(self, transformation_cache, state_manager):
"""See `SequenceDenseColumn` base class."""
if not isinstance(self.categorical_column, SequenceCategoricalColumn):
@@ -3496,7 +4003,36 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn,
dense_tensor = transformation_cache.get(self, state_manager)
sparse_tensors = self.categorical_column.get_sparse_tensors(
transformation_cache, state_manager)
- sequence_length = _sequence_length_from_sparse_tensor(
+ sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
+ sparse_tensors.id_tensor)
+ return SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sequence_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ # Do nothing with weight_collections and trainable since no variables are
+ # created in this function.
+ del weight_collections
+ del trainable
+ if not isinstance(
+ self.categorical_column,
+ (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access
+ raise ValueError(
+ 'In indicator_column: {}. '
+ 'categorical_column must be of type _SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ # Feature has been already transformed. Return the intermediate
+ # representation created by _transform_feature.
+ dense_tensor = inputs.get(self)
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
sparse_tensors.id_tensor)
return SequenceDenseColumn.TensorSequenceLengthPair(
dense_tensor=dense_tensor, sequence_length=sequence_length)
@@ -3518,28 +4054,19 @@ def _verify_static_batch_size_equality(tensors, columns):
expected_batch_size, tensors[i].shape[0]))
-def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
- """Returns a [batch_size] Tensor with per-example sequence length."""
- with ops.name_scope(None, 'sequence_length') as name_scope:
- row_ids = sp_tensor.indices[:, 0]
- column_ids = sp_tensor.indices[:, 1]
- column_ids += array_ops.ones_like(column_ids)
- seq_length = math_ops.to_int64(
- math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements)
- # If the last n rows do not have ids, seq_length will have shape
- # [batch_size - n]. Pad the remaining values with zeros.
- n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1]
- padding = array_ops.zeros(n_pad, dtype=seq_length.dtype)
- return array_ops.concat([seq_length, padding], axis=0, name=name_scope)
-
-
-class SequenceCategoricalColumn(FeatureColumn,
- collections.namedtuple(
- 'SequenceCategoricalColumn',
- ('categorical_column'))):
+class SequenceCategoricalColumn(
+ FeatureColumn,
+ fc_old._CategoricalColumn, # pylint: disable=protected-access
+ collections.namedtuple('SequenceCategoricalColumn',
+ ('categorical_column'))):
"""Represents sequences of categorical data."""
@property
+ def _is_v2_column(self):
+ return (isinstance(self.categorical_column, FeatureColumn) and
+ self.categorical_column._is_v2_column) # pylint: disable=protected-access
+
+ @property
def name(self):
"""See `FeatureColumn` base class."""
return self.categorical_column.name
@@ -3549,16 +4076,46 @@ class SequenceCategoricalColumn(FeatureColumn,
"""See `FeatureColumn` base class."""
return self.categorical_column.parse_example_spec
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _parse_example_spec(self):
+ return self.categorical_column._parse_example_spec # pylint: disable=protected-access
+
def transform_feature(self, transformation_cache, state_manager):
"""See `FeatureColumn` base class."""
return self.categorical_column.transform_feature(transformation_cache,
state_manager)
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _transform_feature(self, inputs):
+ return self.categorical_column._transform_feature(inputs) # pylint: disable=protected-access
+
@property
def num_buckets(self):
"""Returns number of buckets in this sparse feature."""
return self.categorical_column.num_buckets
+ @property
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _num_buckets(self):
+ return self.categorical_column._num_buckets # pylint: disable=protected-access
+
+ def _get_sparse_tensors_helper(self, sparse_tensors):
+ id_tensor = sparse_tensors.id_tensor
+ weight_tensor = sparse_tensors.weight_tensor
+ # Expands third dimension, if necessary so that embeddings are not
+ # combined during embedding lookup. If the tensor is already 3D, leave
+ # as-is.
+ shape = array_ops.shape(id_tensor)
+ target_shape = [shape[0], shape[1], -1]
+ id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
+ if weight_tensor is not None:
+ weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
+ return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
+
def get_sequence_sparse_tensors(self, transformation_cache, state_manager):
"""Returns an IdWeightPair.
@@ -3580,27 +4137,11 @@ class SequenceCategoricalColumn(FeatureColumn,
"""
sparse_tensors = self.categorical_column.get_sparse_tensors(
transformation_cache, state_manager)
- id_tensor = sparse_tensors.id_tensor
- weight_tensor = sparse_tensors.weight_tensor
- # Expands final dimension, so that embeddings are not combined during
- # embedding lookup.
- check_id_rank = check_ops.assert_equal(
- array_ops.rank(id_tensor), 2,
- data=[
- 'Column {} expected ID tensor of rank 2. '.format(self.name),
- 'id_tensor shape: ', array_ops.shape(id_tensor)])
- with ops.control_dependencies([check_id_rank]):
- id_tensor = sparse_ops.sparse_reshape(
- id_tensor,
- shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0))
- if weight_tensor is not None:
- check_weight_rank = check_ops.assert_equal(
- array_ops.rank(weight_tensor), 2,
- data=[
- 'Column {} expected weight tensor of rank 2.'.format(self.name),
- 'weight_tensor shape:', array_ops.shape(weight_tensor)])
- with ops.control_dependencies([check_weight_rank]):
- weight_tensor = sparse_ops.sparse_reshape(
- weight_tensor,
- shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
- return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
+ return self._get_sparse_tensors_helper(sparse_tensors)
+
+ @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
+ _FEATURE_COLUMN_DEPRECATION)
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ return self._get_sparse_tensors_helper(sparse_tensors)