diff options
Diffstat (limited to 'tensorflow/python/feature_column/feature_column_v2.py')
-rw-r--r-- | tensorflow/python/feature_column/feature_column_v2.py | 869 |
1 files changed, 705 insertions, 164 deletions
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index b79373c475..6d089de991 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -136,6 +136,7 @@ import six from tensorflow.python.eager import context +from tensorflow.python.feature_column import feature_column as fc_old from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib @@ -157,9 +158,16 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_utils +from tensorflow.python.util import deprecation from tensorflow.python.util import nest +_FEATURE_COLUMN_DEPRECATION_DATE = '2018-11-30' +_FEATURE_COLUMN_DEPRECATION = ('The old _FeatureColumn APIs are being ' + 'deprecated. Please use the new FeatureColumn ' + 'APIs instead.') + + class StateManager(object): """Manages the state associated with FeatureColumns. @@ -440,10 +448,6 @@ class FeatureLayer(Layer): return (input_shape[0], total_elements) -def _strip_leading_slashes(name): - return name.rsplit('/', 1)[-1] - - class LinearModel(Layer): """Produces a linear prediction `Tensor` based on given `feature_columns`. @@ -775,12 +779,12 @@ def embedding_column( categorical_column, dimension, combiner='mean', initializer=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None, trainable=True): - """`_DenseColumn` that converts from sparse, categorical input. + """`DenseColumn` that converts from sparse, categorical input. Use this when your inputs are sparse, but you want to convert them to a dense representation (e.g., to feed to a DNN). - Inputs must be a `_CategoricalColumn` created by any of the + Inputs must be a `CategoricalColumn` created by any of the `categorical_column_*` function. Here is an example of using `embedding_column` with `DNNClassifier`: @@ -814,12 +818,12 @@ def embedding_column( ``` Args: - categorical_column: A `_CategoricalColumn` created by a + categorical_column: A `CategoricalColumn` created by a `categorical_column_with_*` function. This column produces the sparse IDs that are inputs to the embedding lookup. dimension: An integer specifying dimension of the embedding, must be > 0. - combiner: A string specifying how to reduce if there are multiple entries - in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with + combiner: A string specifying how to reduce if there are multiple entries in + a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with 'mean' the default. 'sqrtn' often achieves good accuracy, in particular with bag-of-words columns. Each of this can be thought as example level normalizations on the column. For more information, see @@ -830,14 +834,14 @@ def embedding_column( `1/sqrt(dimension)`. ckpt_to_load_from: String representing checkpoint name/pattern from which to restore column weights. Required if `tensor_name_in_ckpt` is not `None`. - tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from - which to restore the column weights. Required if `ckpt_to_load_from` is - not `None`. + tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which + to restore the column weights. Required if `ckpt_to_load_from` is not + `None`. max_norm: If not `None`, embedding values are l2-normalized to this value. trainable: Whether or not the embedding is trainable. Default is True. Returns: - `_DenseColumn` that converts from sparse input. + `DenseColumn` that converts from sparse input. Raises: ValueError: if `dimension` not > 0. @@ -1181,7 +1185,7 @@ def bucketized_column(source_column, boundaries): one-dimensional. ValueError: If `boundaries` is not a sorted list or tuple. """ - if not isinstance(source_column, NumericColumn): + if not isinstance(source_column, (NumericColumn, fc_old._NumericColumn)): # pylint: disable=protected-access raise ValueError( 'source_column must be a column generated with numeric_column(). ' 'Given: {}'.format(source_column)) @@ -1390,7 +1394,7 @@ def categorical_column_with_vocabulary_file(key, def categorical_column_with_vocabulary_list( key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0): - """A `_CategoricalColumn` with in-memory vocabulary. + """A `CategoricalColumn` with in-memory vocabulary. Use this when your inputs are in string or integer format, and you have an in-memory vocabulary mapping each value to an integer ID. By default, @@ -1439,14 +1443,14 @@ def categorical_column_with_vocabulary_list( ``` Args: - key: A unique string identifying the input feature. It is used as the - column name and the dictionary key for feature parsing configs, feature - `Tensor` objects, and feature columns. + key: A unique string identifying the input feature. It is used as the column + name and the dictionary key for feature parsing configs, feature `Tensor` + objects, and feature columns. vocabulary_list: An ordered iterable defining the vocabulary. Each feature is mapped to the index of its value (if present) in `vocabulary_list`. Must be castable to `dtype`. - dtype: The type of features. Only string and integer types are supported. - If `None`, it will be inferred from `vocabulary_list`. + dtype: The type of features. Only string and integer types are supported. If + `None`, it will be inferred from `vocabulary_list`. default_value: The integer ID value to return for out-of-vocabulary feature values, defaults to `-1`. This can not be specified with a positive `num_oov_buckets`. @@ -1604,7 +1608,7 @@ def indicator_column(categorical_column): def weighted_categorical_column( categorical_column, weight_feature_key, dtype=dtypes.float32): - """Applies weight values to a `_CategoricalColumn`. + """Applies weight values to a `CategoricalColumn`. Use this when each of your sparse inputs has both an ID and a value. For example, if you're representing text documents as a collection of word @@ -1655,7 +1659,7 @@ def weighted_categorical_column( the same indices and dense shape. Args: - categorical_column: A `_CategoricalColumn` created by + categorical_column: A `CategoricalColumn` created by `categorical_column_with_*` functions. weight_feature_key: String key for weight values. dtype: Type of weights, such as `tf.float32`. Only float and integer weights @@ -1788,12 +1792,13 @@ def crossed_column(keys, hash_bucket_size, hash_key=None): 'keys must be a list with length > 1. Given: {}'.format(keys)) for key in keys: if (not isinstance(key, six.string_types) and - not isinstance(key, CategoricalColumn)): + not isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn))): # pylint: disable=protected-access raise ValueError( 'Unsupported key type. All keys must be either string, or ' 'categorical column except HashedCategoricalColumn. ' 'Given: {}'.format(key)) - if isinstance(key, HashedCategoricalColumn): + if isinstance(key, + (HashedCategoricalColumn, fc_old._HashedCategoricalColumn)): # pylint: disable=protected-access raise ValueError( 'categorical_column_with_hash_bucket is not supported for crossing. ' 'Hashing before crossing will increase probability of collision. ' @@ -1882,6 +1887,16 @@ class FeatureColumn(object): """ pass + @abc.abstractproperty + def _is_v2_column(self): + """Returns whether this FeatureColumn is fully conformant to the new API. + + This is needed for composition type cases where an EmbeddingColumn etc. + might take in old categorical columns as input and then we want to use the + old API. + """ + pass + class DenseColumn(FeatureColumn): """Represents a column which can be represented as `Tensor`. @@ -1927,6 +1942,8 @@ def is_feature_column_v2(feature_columns): for feature_column in feature_columns: if not isinstance(feature_column, FeatureColumn): return False + if not feature_column._is_v2_column: # pylint: disable=protected-access + return False return True @@ -2202,19 +2219,6 @@ class FeatureTransformationCache(object): # TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py -def _shape_offsets(shape): - """Returns moving offset for each dimension given shape.""" - offsets = [] - for dim in reversed(shape): - if offsets: - offsets.append(dim * offsets[-1]) - else: - offsets.append(dim) - offsets.reverse() - return offsets - - -# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None): """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells. @@ -2306,12 +2310,17 @@ def _normalize_feature_columns(feature_columns): class NumericColumn( DenseColumn, + fc_old._DenseColumn, # pylint: disable=protected-access collections.namedtuple( 'NumericColumn', ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))): """see `numeric_column`.""" @property + def _is_v2_column(self): + return True + + @property def name(self): """See `FeatureColumn` base class.""" return self.key @@ -2325,6 +2334,27 @@ class NumericColumn( self.default_value) } + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _parse_example_spec(self): + return self.parse_example_spec + + def _transform_input_tensor(self, input_tensor): + if isinstance(input_tensor, sparse_tensor_lib.SparseTensor): + raise ValueError( + 'The corresponding Tensor of numerical column must be a Tensor. ' + 'SparseTensor is not supported. key: {}'.format(self.key)) + if self.normalizer_fn is not None: + input_tensor = self.normalizer_fn(input_tensor) + return math_ops.to_float(input_tensor) + + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _transform_feature(self, inputs): + input_tensor = inputs.get(self.key) + return self._transform_input_tensor(input_tensor) + def transform_feature(self, transformation_cache, state_manager): """See `FeatureColumn` base class. @@ -2342,19 +2372,19 @@ class NumericColumn( ValueError: If a SparseTensor is passed in. """ input_tensor = transformation_cache.get(self.key, state_manager) - if isinstance(input_tensor, sparse_tensor_lib.SparseTensor): - raise ValueError( - 'The corresponding Tensor of numerical column must be a Tensor. ' - 'SparseTensor is not supported. key: {}'.format(self.key)) - if self.normalizer_fn is not None: - input_tensor = self.normalizer_fn(input_tensor) - return math_ops.to_float(input_tensor) + return self._transform_input_tensor(input_tensor) @property def variable_shape(self): """See `DenseColumn` base class.""" return tensor_shape.TensorShape(self.shape) + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _variable_shape(self): + return self.variable_shape + def get_dense_tensor(self, transformation_cache, state_manager): """Returns dense `Tensor` representing numeric feature. @@ -2371,13 +2401,29 @@ class NumericColumn( # representation created by _transform_feature. return transformation_cache.get(self, state_manager) + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + del weight_collections + del trainable + return inputs.get(self) + -class BucketizedColumn(DenseColumn, CategoricalColumn, - collections.namedtuple('BucketizedColumn', - ('source_column', 'boundaries'))): +class BucketizedColumn( + DenseColumn, + CategoricalColumn, + fc_old._DenseColumn, # pylint: disable=protected-access + fc_old._CategoricalColumn, # pylint: disable=protected-access + collections.namedtuple('BucketizedColumn', + ('source_column', 'boundaries'))): """See `bucketized_column`.""" @property + def _is_v2_column(self): + return (isinstance(self.source_column, FeatureColumn) and + self.source_column._is_v2_column) # pylint: disable=protected-access + + @property def name(self): """See `FeatureColumn` base class.""" return '{}_bucketized'.format(self.source_column.name) @@ -2387,6 +2433,21 @@ class BucketizedColumn(DenseColumn, CategoricalColumn, """See `FeatureColumn` base class.""" return self.source_column.parse_example_spec + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _parse_example_spec(self): + return self.source_column._parse_example_spec # pylint: disable=protected-access + + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _transform_feature(self, inputs): + """Returns bucketized categorical `source_column` tensor.""" + source_tensor = inputs.get(self.source_column) + return math_ops._bucketize( # pylint: disable=protected-access + source_tensor, + boundaries=self.boundaries) + def transform_feature(self, transformation_cache, state_manager): """Returns bucketized categorical `source_column` tensor.""" source_tensor = transformation_cache.get(self.source_column, state_manager) @@ -2400,24 +2461,45 @@ class BucketizedColumn(DenseColumn, CategoricalColumn, return tensor_shape.TensorShape( tuple(self.source_column.shape) + (len(self.boundaries) + 1,)) - def get_dense_tensor(self, transformation_cache, state_manager): - """Returns one hot encoded dense `Tensor`.""" - input_tensor = transformation_cache.get(self, state_manager) + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _variable_shape(self): + return self.variable_shape + + def _get_dense_tensor_for_input_tensor(self, input_tensor): return array_ops.one_hot( indices=math_ops.to_int64(input_tensor), depth=len(self.boundaries) + 1, on_value=1., off_value=0.) + def get_dense_tensor(self, transformation_cache, state_manager): + """Returns one hot encoded dense `Tensor`.""" + input_tensor = transformation_cache.get(self, state_manager) + return self._get_dense_tensor_for_input_tensor(input_tensor) + + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + del weight_collections + del trainable + input_tensor = inputs.get(self) + return self._get_dense_tensor_for_input_tensor(input_tensor) + @property def num_buckets(self): """See `CategoricalColumn` base class.""" # By construction, source_column is always one-dimensional. return (len(self.boundaries) + 1) * self.source_column.shape[0] - def get_sparse_tensors(self, transformation_cache, state_manager): - """Converts dense inputs to SparseTensor so downstream code can use it.""" - input_tensor = transformation_cache.get(self, state_manager) + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _num_buckets(self): + return self.num_buckets + + def _get_sparse_tensors_for_input_tensor(self, input_tensor): batch_size = array_ops.shape(input_tensor)[0] # By construction, source_column is always one-dimensional. source_dimension = self.source_column.shape[0] @@ -2443,9 +2525,27 @@ class BucketizedColumn(DenseColumn, CategoricalColumn, dense_shape=dense_shape) return CategoricalColumn.IdWeightPair(sparse_tensor, None) + def get_sparse_tensors(self, transformation_cache, state_manager): + """Converts dense inputs to SparseTensor so downstream code can use it.""" + input_tensor = transformation_cache.get(self, state_manager) + return self._get_sparse_tensors_for_input_tensor(input_tensor) + + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _get_sparse_tensors(self, inputs, weight_collections=None, + trainable=None): + """Converts dense inputs to SparseTensor so downstream code can use it.""" + del weight_collections + del trainable + input_tensor = inputs.get(self) + return self._get_sparse_tensors_for_input_tensor(input_tensor) + class EmbeddingColumn( - DenseColumn, SequenceDenseColumn, + DenseColumn, + SequenceDenseColumn, + fc_old._DenseColumn, # pylint: disable=protected-access + fc_old._SequenceDenseColumn, # pylint: disable=protected-access collections.namedtuple( 'EmbeddingColumn', ('categorical_column', 'dimension', 'combiner', 'initializer', @@ -2453,6 +2553,11 @@ class EmbeddingColumn( """See `embedding_column`.""" @property + def _is_v2_column(self): + return (isinstance(self.categorical_column, FeatureColumn) and + self.categorical_column._is_v2_column) # pylint: disable=protected-access + + @property def name(self): """See `FeatureColumn` base class.""" return '{}_embedding'.format(self.categorical_column.name) @@ -2462,18 +2567,35 @@ class EmbeddingColumn( """See `FeatureColumn` base class.""" return self.categorical_column.parse_example_spec + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _parse_example_spec(self): + return self.categorical_column._parse_example_spec # pylint: disable=protected-access + def transform_feature(self, transformation_cache, state_manager): """Transforms underlying `categorical_column`.""" return transformation_cache.get(self.categorical_column, state_manager) + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _transform_feature(self, inputs): + return inputs.get(self.categorical_column) + @property def variable_shape(self): """See `DenseColumn` base class.""" return tensor_shape.vector(self.dimension) + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _variable_shape(self): + return self.variable_shape + def create_state(self, state_manager): """Creates the embedding lookup variable.""" - embedding_shape = (self.categorical_column.num_buckets, self.dimension) + embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access state_manager.create_variable( self, name='embedding_weights', @@ -2482,17 +2604,11 @@ class EmbeddingColumn( trainable=self.trainable, initializer=self.initializer) - def _get_dense_tensor_internal(self, transformation_cache, state_manager): - """Private method that follows the signature of _get_dense_tensor.""" - # Get sparse IDs and weights. - sparse_tensors = self.categorical_column.get_sparse_tensors( - transformation_cache, state_manager) + def _get_dense_tensor_internal_helper(self, sparse_tensors, + embedding_weights): sparse_ids = sparse_tensors.id_tensor sparse_weights = sparse_tensors.weight_tensor - embedding_weights = state_manager.get_variable( - self, name='embedding_weights') - if self.ckpt_to_load_from is not None: to_restore = embedding_weights if isinstance(to_restore, variables.PartitionedVariable): @@ -2510,6 +2626,30 @@ class EmbeddingColumn( name='%s_weights' % self.name, max_norm=self.max_norm) + def _get_dense_tensor_internal(self, sparse_tensors, state_manager): + """Private method that follows the signature of get_dense_tensor.""" + embedding_weights = state_manager.get_variable( + self, name='embedding_weights') + return self._get_dense_tensor_internal_helper(sparse_tensors, + embedding_weights) + + def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections, + trainable): + """Private method that follows the signature of _get_dense_tensor.""" + embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access + if (weight_collections and + ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections): + weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES) + embedding_weights = variable_scope.get_variable( + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=self.initializer, + trainable=self.trainable and trainable, + collections=weight_collections) + return self._get_dense_tensor_internal_helper(sparse_tensors, + embedding_weights) + def get_dense_tensor(self, transformation_cache, state_manager): """Returns tensor after doing the embedding lookup. @@ -2535,7 +2675,30 @@ class EmbeddingColumn( 'sequence_input_layer instead of input_layer. ' 'Given (type {}): {}'.format(self.name, type(self.categorical_column), self.categorical_column)) - return self._get_dense_tensor_internal(transformation_cache, state_manager) + # Get sparse IDs and weights. + sparse_tensors = self.categorical_column.get_sparse_tensors( + transformation_cache, state_manager) + return self._get_dense_tensor_internal(sparse_tensors, state_manager) + + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + if isinstance( + self.categorical_column, + (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access + raise ValueError( + 'In embedding_column: {}. ' + 'categorical_column must not be of type _SequenceCategoricalColumn. ' + 'Suggested fix A: If you wish to use input_layer, use a ' + 'non-sequence categorical_column_with_*. ' + 'Suggested fix B: If you wish to create sequence input, use ' + 'sequence_input_layer instead of input_layer. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + sparse_tensors = self.categorical_column._get_sparse_tensors( # pylint: disable=protected-access + inputs, weight_collections, trainable) + return self._old_get_dense_tensor_internal(sparse_tensors, + weight_collections, trainable) def get_sequence_dense_tensor(self, transformation_cache, state_manager): """See `SequenceDenseColumn` base class.""" @@ -2547,21 +2710,40 @@ class EmbeddingColumn( 'Suggested fix: Use one of sequence_categorical_column_with_*. ' 'Given (type {}): {}'.format(self.name, type(self.categorical_column), self.categorical_column)) - dense_tensor = self._get_dense_tensor_internal( # pylint: disable=protected-access + sparse_tensors = self.categorical_column.get_sequence_sparse_tensors( transformation_cache, state_manager) - sparse_tensors = self.categorical_column.get_sparse_tensors( - transformation_cache, state_manager) - sequence_length = _sequence_length_from_sparse_tensor( + dense_tensor = self._get_dense_tensor_internal(sparse_tensors, + state_manager) + sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access sparse_tensors.id_tensor) return SequenceDenseColumn.TensorSequenceLengthPair( dense_tensor=dense_tensor, sequence_length=sequence_length) - -def _get_graph_for_variable(var): - if isinstance(var, variables.PartitionedVariable): - return list(var)[0].graph - else: - return var.graph + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _get_sequence_dense_tensor(self, + inputs, + weight_collections=None, + trainable=None): + if not isinstance( + self.categorical_column, + (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access + raise ValueError( + 'In embedding_column: {}. ' + 'categorical_column must be of type _SequenceCategoricalColumn ' + 'to use sequence_input_layer. ' + 'Suggested fix: Use one of sequence_categorical_column_with_*. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access + dense_tensor = self._old_get_dense_tensor_internal( + sparse_tensors, + weight_collections=weight_collections, + trainable=trainable) + sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access + sparse_tensors.id_tensor) + return SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) class SharedEmbeddingStateManager(Layer): @@ -2633,8 +2815,17 @@ def maybe_create_shared_state_manager(feature_columns): return None +def _raise_shared_embedding_column_error(): + raise ValueError('SharedEmbeddingColumns are not supported in ' + '`linear_model` or `input_layer`. Please use ' + '`FeatureLayer` or `LinearModel` instead.') + + class SharedEmbeddingColumn( - DenseColumn, SequenceDenseColumn, + DenseColumn, + SequenceDenseColumn, + fc_old._DenseColumn, # pylint: disable=protected-access + fc_old._SequenceDenseColumn, # pylint: disable=protected-access collections.namedtuple( 'SharedEmbeddingColumn', ('categorical_column', 'dimension', 'combiner', 'initializer', @@ -2643,6 +2834,10 @@ class SharedEmbeddingColumn( """See `embedding_column`.""" @property + def _is_v2_column(self): + return True + + @property def name(self): """See `FeatureColumn` base class.""" return '{}_shared_embedding'.format(self.categorical_column.name) @@ -2662,15 +2857,26 @@ class SharedEmbeddingColumn( """See `FeatureColumn` base class.""" return self.categorical_column.parse_example_spec + @property + def _parse_example_spec(self): + return _raise_shared_embedding_column_error() + def transform_feature(self, transformation_cache, state_manager): """See `FeatureColumn` base class.""" return transformation_cache.get(self.categorical_column, state_manager) + def _transform_feature(self, inputs): + return _raise_shared_embedding_column_error() + @property def variable_shape(self): """See `DenseColumn` base class.""" return tensor_shape.vector(self.dimension) + @property + def _variable_shape(self): + return _raise_shared_embedding_column_error() + def create_state(self, state_manager): """Creates the shared embedding lookup variable.""" if not isinstance(state_manager, SharedEmbeddingStateManager): @@ -2731,6 +2937,9 @@ class SharedEmbeddingColumn( self.categorical_column)) return self._get_dense_tensor_internal(transformation_cache, state_manager) + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + return _raise_shared_embedding_column_error() + def get_sequence_dense_tensor(self, transformation_cache, state_manager): """See `SequenceDenseColumn` base class.""" if not isinstance(self.categorical_column, SequenceCategoricalColumn): @@ -2745,11 +2954,17 @@ class SharedEmbeddingColumn( state_manager) sparse_tensors = self.categorical_column.get_sparse_tensors( transformation_cache, state_manager) - sequence_length = _sequence_length_from_sparse_tensor( + sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access sparse_tensors.id_tensor) return SequenceDenseColumn.TensorSequenceLengthPair( dense_tensor=dense_tensor, sequence_length=sequence_length) + def _get_sequence_dense_tensor(self, + inputs, + weight_collections=None, + trainable=None): + return _raise_shared_embedding_column_error() + def _create_tuple(shape, value): """Returns a tuple with given shape and filled with value.""" @@ -2858,11 +3073,16 @@ def _check_default_value(shape, default_value, dtype, key): class HashedCategoricalColumn( CategoricalColumn, + fc_old._CategoricalColumn, # pylint: disable=protected-access collections.namedtuple('HashedCategoricalColumn', ('key', 'hash_bucket_size', 'dtype'))): """see `categorical_column_with_hash_bucket`.""" @property + def _is_v2_column(self): + return True + + @property def name(self): """See `FeatureColumn` base class.""" return self.key @@ -2872,10 +3092,14 @@ class HashedCategoricalColumn( """See `FeatureColumn` base class.""" return {self.key: parsing_ops.VarLenFeature(self.dtype)} - def transform_feature(self, transformation_cache, state_manager): + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _parse_example_spec(self): + return self.parse_example_spec + + def _transform_input_tensor(self, input_tensor): """Hashes the values in the feature_column.""" - input_tensor = _to_sparse_input_and_drop_ignore_values( - transformation_cache.get(self.key, state_manager)) if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor): raise ValueError('SparseColumn input must be a SparseTensor.') @@ -2899,25 +3123,56 @@ class HashedCategoricalColumn( return sparse_tensor_lib.SparseTensor( input_tensor.indices, sparse_id_values, input_tensor.dense_shape) + def transform_feature(self, transformation_cache, state_manager): + """Hashes the values in the feature_column.""" + input_tensor = _to_sparse_input_and_drop_ignore_values( + transformation_cache.get(self.key, state_manager)) + return self._transform_input_tensor(input_tensor) + + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _transform_feature(self, inputs): + input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) + return self._transform_input_tensor(input_tensor) + @property def num_buckets(self): """Returns number of buckets in this sparse feature.""" return self.hash_bucket_size + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _num_buckets(self): + return self.num_buckets + def get_sparse_tensors(self, transformation_cache, state_manager): """See `CategoricalColumn` base class.""" return CategoricalColumn.IdWeightPair( transformation_cache.get(self, state_manager), None) + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _get_sparse_tensors(self, inputs, weight_collections=None, + trainable=None): + del weight_collections + del trainable + return CategoricalColumn.IdWeightPair(inputs.get(self), None) + class VocabularyFileCategoricalColumn( CategoricalColumn, + fc_old._CategoricalColumn, # pylint: disable=protected-access collections.namedtuple('VocabularyFileCategoricalColumn', ('key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets', 'dtype', 'default_value'))): """See `categorical_column_with_vocabulary_file`.""" @property + def _is_v2_column(self): + return True + + @property def name(self): """See `FeatureColumn` base class.""" return self.key @@ -2927,11 +3182,14 @@ class VocabularyFileCategoricalColumn( """See `FeatureColumn` base class.""" return {self.key: parsing_ops.VarLenFeature(self.dtype)} - def transform_feature(self, transformation_cache, state_manager): - """Creates a lookup table for the vocabulary.""" - input_tensor = _to_sparse_input_and_drop_ignore_values( - transformation_cache.get(self.key, state_manager)) + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _parse_example_spec(self): + return self.parse_example_spec + def _transform_input_tensor(self, input_tensor): + """Creates a lookup table for the vocabulary.""" if self.dtype.is_integer != input_tensor.dtype.is_integer: raise ValueError( 'Column dtype and SparseTensors dtype must be compatible. ' @@ -2957,19 +3215,46 @@ class VocabularyFileCategoricalColumn( key_dtype=key_dtype, name='{}_lookup'.format(self.key)).lookup(input_tensor) + def transform_feature(self, transformation_cache, state_manager): + """Creates a lookup table for the vocabulary.""" + input_tensor = _to_sparse_input_and_drop_ignore_values( + transformation_cache.get(self.key, state_manager)) + return self._transform_input_tensor(input_tensor) + + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _transform_feature(self, inputs): + input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) + return self._transform_input_tensor(input_tensor) + @property def num_buckets(self): """Returns number of buckets in this sparse feature.""" return self.vocabulary_size + self.num_oov_buckets + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _num_buckets(self): + return self.num_buckets + def get_sparse_tensors(self, transformation_cache, state_manager): """See `CategoricalColumn` base class.""" return CategoricalColumn.IdWeightPair( transformation_cache.get(self, state_manager), None) + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _get_sparse_tensors(self, inputs, weight_collections=None, + trainable=None): + del weight_collections + del trainable + return CategoricalColumn.IdWeightPair(inputs.get(self), None) + class VocabularyListCategoricalColumn( CategoricalColumn, + fc_old._CategoricalColumn, # pylint: disable=protected-access collections.namedtuple( 'VocabularyListCategoricalColumn', ('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets')) @@ -2977,6 +3262,10 @@ class VocabularyListCategoricalColumn( """See `categorical_column_with_vocabulary_list`.""" @property + def _is_v2_column(self): + return True + + @property def name(self): """See `FeatureColumn` base class.""" return self.key @@ -2986,11 +3275,14 @@ class VocabularyListCategoricalColumn( """See `FeatureColumn` base class.""" return {self.key: parsing_ops.VarLenFeature(self.dtype)} - def transform_feature(self, transformation_cache, state_manager): - """Creates a lookup table for the vocabulary list.""" - input_tensor = _to_sparse_input_and_drop_ignore_values( - transformation_cache.get(self.key, state_manager)) + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _parse_example_spec(self): + return self.parse_example_spec + def _transform_input_tensor(self, input_tensor): + """Creates a lookup table for the vocabulary list.""" if self.dtype.is_integer != input_tensor.dtype.is_integer: raise ValueError( 'Column dtype and SparseTensors dtype must be compatible. ' @@ -3015,25 +3307,56 @@ class VocabularyListCategoricalColumn( dtype=key_dtype, name='{}_lookup'.format(self.key)).lookup(input_tensor) + def transform_feature(self, transformation_cache, state_manager): + """Creates a lookup table for the vocabulary list.""" + input_tensor = _to_sparse_input_and_drop_ignore_values( + transformation_cache.get(self.key, state_manager)) + return self._transform_input_tensor(input_tensor) + + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _transform_feature(self, inputs): + input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) + return self._transform_input_tensor(input_tensor) + @property def num_buckets(self): """Returns number of buckets in this sparse feature.""" return len(self.vocabulary_list) + self.num_oov_buckets + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _num_buckets(self): + return self.num_buckets + def get_sparse_tensors(self, transformation_cache, state_manager): """See `CategoricalColumn` base class.""" return CategoricalColumn.IdWeightPair( transformation_cache.get(self, state_manager), None) + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _get_sparse_tensors(self, inputs, weight_collections=None, + trainable=None): + del weight_collections + del trainable + return CategoricalColumn.IdWeightPair(inputs.get(self), None) + class IdentityCategoricalColumn( CategoricalColumn, + fc_old._CategoricalColumn, # pylint: disable=protected-access collections.namedtuple('IdentityCategoricalColumn', ('key', 'number_buckets', 'default_value'))): """See `categorical_column_with_identity`.""" @property + def _is_v2_column(self): + return True + + @property def name(self): """See `FeatureColumn` base class.""" return self.key @@ -3043,11 +3366,14 @@ class IdentityCategoricalColumn( """See `FeatureColumn` base class.""" return {self.key: parsing_ops.VarLenFeature(dtypes.int64)} - def transform_feature(self, transformation_cache, state_manager): - """Returns a SparseTensor with identity values.""" - input_tensor = _to_sparse_input_and_drop_ignore_values( - transformation_cache.get(self.key, state_manager)) + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _parse_example_spec(self): + return self.parse_example_spec + def _transform_input_tensor(self, input_tensor): + """Returns a SparseTensor with identity values.""" if not input_tensor.dtype.is_integer: raise ValueError( 'Invalid input, not integer. key: {} dtype: {}'.format( @@ -3082,25 +3408,57 @@ class IdentityCategoricalColumn( values=values, dense_shape=input_tensor.dense_shape) + def transform_feature(self, transformation_cache, state_manager): + """Returns a SparseTensor with identity values.""" + input_tensor = _to_sparse_input_and_drop_ignore_values( + transformation_cache.get(self.key, state_manager)) + return self._transform_input_tensor(input_tensor) + + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _transform_feature(self, inputs): + input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) + return self._transform_input_tensor(input_tensor) + @property def num_buckets(self): """Returns number of buckets in this sparse feature.""" return self.number_buckets + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _num_buckets(self): + return self.num_buckets + def get_sparse_tensors(self, transformation_cache, state_manager): """See `CategoricalColumn` base class.""" return CategoricalColumn.IdWeightPair( transformation_cache.get(self, state_manager), None) + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _get_sparse_tensors(self, inputs, weight_collections=None, + trainable=None): + del weight_collections + del trainable + return CategoricalColumn.IdWeightPair(inputs.get(self), None) + class WeightedCategoricalColumn( CategoricalColumn, + fc_old._CategoricalColumn, # pylint: disable=protected-access collections.namedtuple( 'WeightedCategoricalColumn', ('categorical_column', 'weight_feature_key', 'dtype'))): """See `weighted_categorical_column`.""" @property + def _is_v2_column(self): + return (isinstance(self.categorical_column, FeatureColumn) and + self.categorical_column._is_v2_column) # pylint: disable=protected-access + + @property def name(self): """See `FeatureColumn` base class.""" return '{}_weighted_by_{}'.format( @@ -3117,14 +3475,28 @@ class WeightedCategoricalColumn( return config @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _parse_example_spec(self): + config = self.categorical_column._parse_example_spec # pylint: disable=protected-access + if self.weight_feature_key in config: + raise ValueError('Parse config {} already exists for {}.'.format( + config[self.weight_feature_key], self.weight_feature_key)) + config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype) + return config + + @property def num_buckets(self): """See `DenseColumn` base class.""" return self.categorical_column.num_buckets - def transform_feature(self, transformation_cache, state_manager): - """Applies weights to tensor generated from `categorical_column`'.""" - weight_tensor = transformation_cache.get(self.weight_feature_key, - state_manager) + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _num_buckets(self): + return self.categorical_column._num_buckets # pylint: disable=protected-access + + def _transform_weight_tensor(self, weight_tensor): if weight_tensor is None: raise ValueError('Missing weights {}.'.format(self.weight_feature_key)) weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor( @@ -3138,27 +3510,63 @@ class WeightedCategoricalColumn( weight_tensor, ignore_value=0.0) if not weight_tensor.dtype.is_floating: weight_tensor = math_ops.to_float(weight_tensor) + return weight_tensor + + def transform_feature(self, transformation_cache, state_manager): + """Applies weights to tensor generated from `categorical_column`'.""" + weight_tensor = transformation_cache.get(self.weight_feature_key, + state_manager) + weight_tensor = self._transform_weight_tensor(weight_tensor) return (transformation_cache.get(self.categorical_column, state_manager), weight_tensor) + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _transform_feature(self, inputs): + """Applies weights to tensor generated from `categorical_column`'.""" + weight_tensor = inputs.get(self.weight_feature_key) + weight_tensor = self._transform_weight_tensor(weight_tensor) + return (inputs.get(self.categorical_column), weight_tensor) + def get_sparse_tensors(self, transformation_cache, state_manager): """See `CategoricalColumn` base class.""" tensors = transformation_cache.get(self, state_manager) return CategoricalColumn.IdWeightPair(tensors[0], tensors[1]) + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _get_sparse_tensors(self, inputs, weight_collections=None, + trainable=None): + del weight_collections + del trainable + tensors = inputs.get(self) + return CategoricalColumn.IdWeightPair(tensors[0], tensors[1]) + class CrossedColumn( CategoricalColumn, + fc_old._CategoricalColumn, # pylint: disable=protected-access collections.namedtuple('CrossedColumn', ('keys', 'hash_bucket_size', 'hash_key'))): """See `crossed_column`.""" @property + def _is_v2_column(self): + for key in _collect_leaf_level_keys(self): + if isinstance(key, six.string_types): + continue + if not isinstance(key, FeatureColumn): + return False + if not key._is_v2_column: # pylint: disable=protected-access + return False + return True + + @property def name(self): """See `FeatureColumn` base class.""" feature_names = [] for key in _collect_leaf_level_keys(self): - if isinstance(key, FeatureColumn): + if isinstance(key, (FeatureColumn, fc_old._FeatureColumn)): # pylint: disable=protected-access feature_names.append(key.name) else: # key must be a string feature_names.append(key) @@ -3171,17 +3579,25 @@ class CrossedColumn( for key in self.keys: if isinstance(key, FeatureColumn): config.update(key.parse_example_spec) + elif isinstance(key, fc_old._FeatureColumn): # pylint: disable=protected-access + config.update(key._parse_example_spec) # pylint: disable=protected-access else: # key must be a string config.update({key: parsing_ops.VarLenFeature(dtypes.string)}) return config + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _parse_example_spec(self): + return self.parse_example_spec + def transform_feature(self, transformation_cache, state_manager): """Generates a hashed sparse cross from the input tensors.""" feature_tensors = [] for key in _collect_leaf_level_keys(self): if isinstance(key, six.string_types): feature_tensors.append(transformation_cache.get(key, state_manager)) - elif isinstance(key, CategoricalColumn): + elif isinstance(key, (fc_old._CategoricalColumn, CategoricalColumn)): # pylint: disable=protected-access ids_and_weights = key.get_sparse_tensors(transformation_cache, state_manager) if ids_and_weights.weight_tensor is not None: @@ -3197,16 +3613,54 @@ class CrossedColumn( num_buckets=self.hash_bucket_size, hash_key=self.hash_key) + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _transform_feature(self, inputs): + """Generates a hashed sparse cross from the input tensors.""" + feature_tensors = [] + for key in _collect_leaf_level_keys(self): + if isinstance(key, six.string_types): + feature_tensors.append(inputs.get(key)) + elif isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn)): # pylint: disable=protected-access + ids_and_weights = key._get_sparse_tensors(inputs) # pylint: disable=protected-access + if ids_and_weights.weight_tensor is not None: + raise ValueError( + 'crossed_column does not support weight_tensor, but the given ' + 'column populates weight_tensor. ' + 'Given column: {}'.format(key.name)) + feature_tensors.append(ids_and_weights.id_tensor) + else: + raise ValueError('Unsupported column type. Given: {}'.format(key)) + return sparse_ops.sparse_cross_hashed( + inputs=feature_tensors, + num_buckets=self.hash_bucket_size, + hash_key=self.hash_key) + @property def num_buckets(self): """Returns number of buckets in this sparse feature.""" return self.hash_bucket_size + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _num_buckets(self): + return self.num_buckets + def get_sparse_tensors(self, transformation_cache, state_manager): """See `CategoricalColumn` base class.""" return CategoricalColumn.IdWeightPair( transformation_cache.get(self, state_manager), None) + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _get_sparse_tensors(self, inputs, weight_collections=None, + trainable=None): + """See `CategoricalColumn` base class.""" + del weight_collections + del trainable + return CategoricalColumn.IdWeightPair(inputs.get(self), None) + def _collect_leaf_level_keys(cross): """Collects base keys by expanding all nested crosses. @@ -3382,9 +3836,12 @@ def _prune_invalid_weights(sparse_ids, sparse_weights): return sparse_ids, sparse_weights -class IndicatorColumn(DenseColumn, SequenceDenseColumn, - collections.namedtuple('IndicatorColumn', - ('categorical_column'))): +class IndicatorColumn( + DenseColumn, + SequenceDenseColumn, + fc_old._DenseColumn, # pylint: disable=protected-access + fc_old._SequenceDenseColumn, # pylint: disable=protected-access + collections.namedtuple('IndicatorColumn', ('categorical_column'))): """Represents a one-hot column for use in deep networks. Args: @@ -3393,27 +3850,16 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn, """ @property + def _is_v2_column(self): + return (isinstance(self.categorical_column, FeatureColumn) and + self.categorical_column._is_v2_column) # pylint: disable=protected-access + + @property def name(self): """See `FeatureColumn` base class.""" return '{}_indicator'.format(self.categorical_column.name) - def transform_feature(self, transformation_cache, state_manager): - """Returns dense `Tensor` representing feature. - - Args: - transformation_cache: A `FeatureTransformationCache` object to access - features. - state_manager: A `StateManager` to create / access resources such as - lookup tables. - - Returns: - Transformed feature `Tensor`. - - Raises: - ValueError: if input rank is not known at graph building time. - """ - id_weight_pair = self.categorical_column.get_sparse_tensors( - transformation_cache, state_manager) + def _transform_id_weight_pair(self, id_weight_pair): id_tensor = id_weight_pair.id_tensor weight_tensor = id_weight_pair.weight_tensor @@ -3422,7 +3868,7 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn, weighted_column = sparse_ops.sparse_merge( sp_ids=id_tensor, sp_values=weight_tensor, - vocab_size=int(self.variable_shape[-1])) + vocab_size=int(self._variable_shape[-1])) # Remove (?, -1) index weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0], weighted_column.dense_shape) @@ -3435,22 +3881,62 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn, # input_layer are float32. one_hot_id_tensor = array_ops.one_hot( dense_id_tensor, - depth=self.variable_shape[-1], + depth=self._variable_shape[-1], on_value=1.0, off_value=0.0) # Reduce to get a multi-hot per example. return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2]) + def transform_feature(self, transformation_cache, state_manager): + """Returns dense `Tensor` representing feature. + + Args: + transformation_cache: A `FeatureTransformationCache` object to access + features. + state_manager: A `StateManager` to create / access resources such as + lookup tables. + + Returns: + Transformed feature `Tensor`. + + Raises: + ValueError: if input rank is not known at graph building time. + """ + id_weight_pair = self.categorical_column.get_sparse_tensors( + transformation_cache, state_manager) + return self._transform_id_weight_pair(id_weight_pair) + + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _transform_feature(self, inputs): + id_weight_pair = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access + return self._transform_id_weight_pair(id_weight_pair) + @property def parse_example_spec(self): """See `FeatureColumn` base class.""" return self.categorical_column.parse_example_spec @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _parse_example_spec(self): + return self.categorical_column._parse_example_spec # pylint: disable=protected-access + + @property def variable_shape(self): """Returns a `TensorShape` representing the shape of the dense `Tensor`.""" - return tensor_shape.TensorShape([1, self.categorical_column.num_buckets]) + if isinstance(self.categorical_column, FeatureColumn): + return tensor_shape.TensorShape([1, self.categorical_column.num_buckets]) + else: + return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access + + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _variable_shape(self): + return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access def get_dense_tensor(self, transformation_cache, state_manager): """Returns dense `Tensor` representing feature. @@ -3481,6 +3967,27 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn, # representation created by transform_feature. return transformation_cache.get(self, state_manager) + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + del weight_collections + del trainable + if isinstance( + self.categorical_column, + (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access + raise ValueError( + 'In indicator_column: {}. ' + 'categorical_column must not be of type _SequenceCategoricalColumn. ' + 'Suggested fix A: If you wish to use input_layer, use a ' + 'non-sequence categorical_column_with_*. ' + 'Suggested fix B: If you wish to create sequence input, use ' + 'sequence_input_layer instead of input_layer. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + # Feature has been already transformed. Return the intermediate + # representation created by transform_feature. + return inputs.get(self) + def get_sequence_dense_tensor(self, transformation_cache, state_manager): """See `SequenceDenseColumn` base class.""" if not isinstance(self.categorical_column, SequenceCategoricalColumn): @@ -3496,7 +4003,36 @@ class IndicatorColumn(DenseColumn, SequenceDenseColumn, dense_tensor = transformation_cache.get(self, state_manager) sparse_tensors = self.categorical_column.get_sparse_tensors( transformation_cache, state_manager) - sequence_length = _sequence_length_from_sparse_tensor( + sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access + sparse_tensors.id_tensor) + return SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _get_sequence_dense_tensor(self, + inputs, + weight_collections=None, + trainable=None): + # Do nothing with weight_collections and trainable since no variables are + # created in this function. + del weight_collections + del trainable + if not isinstance( + self.categorical_column, + (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access + raise ValueError( + 'In indicator_column: {}. ' + 'categorical_column must be of type _SequenceCategoricalColumn ' + 'to use sequence_input_layer. ' + 'Suggested fix: Use one of sequence_categorical_column_with_*. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + # Feature has been already transformed. Return the intermediate + # representation created by _transform_feature. + dense_tensor = inputs.get(self) + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access + sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access sparse_tensors.id_tensor) return SequenceDenseColumn.TensorSequenceLengthPair( dense_tensor=dense_tensor, sequence_length=sequence_length) @@ -3518,28 +4054,19 @@ def _verify_static_batch_size_equality(tensors, columns): expected_batch_size, tensors[i].shape[0])) -def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1): - """Returns a [batch_size] Tensor with per-example sequence length.""" - with ops.name_scope(None, 'sequence_length') as name_scope: - row_ids = sp_tensor.indices[:, 0] - column_ids = sp_tensor.indices[:, 1] - column_ids += array_ops.ones_like(column_ids) - seq_length = math_ops.to_int64( - math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements) - # If the last n rows do not have ids, seq_length will have shape - # [batch_size - n]. Pad the remaining values with zeros. - n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1] - padding = array_ops.zeros(n_pad, dtype=seq_length.dtype) - return array_ops.concat([seq_length, padding], axis=0, name=name_scope) - - -class SequenceCategoricalColumn(FeatureColumn, - collections.namedtuple( - 'SequenceCategoricalColumn', - ('categorical_column'))): +class SequenceCategoricalColumn( + FeatureColumn, + fc_old._CategoricalColumn, # pylint: disable=protected-access + collections.namedtuple('SequenceCategoricalColumn', + ('categorical_column'))): """Represents sequences of categorical data.""" @property + def _is_v2_column(self): + return (isinstance(self.categorical_column, FeatureColumn) and + self.categorical_column._is_v2_column) # pylint: disable=protected-access + + @property def name(self): """See `FeatureColumn` base class.""" return self.categorical_column.name @@ -3549,16 +4076,46 @@ class SequenceCategoricalColumn(FeatureColumn, """See `FeatureColumn` base class.""" return self.categorical_column.parse_example_spec + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _parse_example_spec(self): + return self.categorical_column._parse_example_spec # pylint: disable=protected-access + def transform_feature(self, transformation_cache, state_manager): """See `FeatureColumn` base class.""" return self.categorical_column.transform_feature(transformation_cache, state_manager) + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _transform_feature(self, inputs): + return self.categorical_column._transform_feature(inputs) # pylint: disable=protected-access + @property def num_buckets(self): """Returns number of buckets in this sparse feature.""" return self.categorical_column.num_buckets + @property + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _num_buckets(self): + return self.categorical_column._num_buckets # pylint: disable=protected-access + + def _get_sparse_tensors_helper(self, sparse_tensors): + id_tensor = sparse_tensors.id_tensor + weight_tensor = sparse_tensors.weight_tensor + # Expands third dimension, if necessary so that embeddings are not + # combined during embedding lookup. If the tensor is already 3D, leave + # as-is. + shape = array_ops.shape(id_tensor) + target_shape = [shape[0], shape[1], -1] + id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape) + if weight_tensor is not None: + weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape) + return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor) + def get_sequence_sparse_tensors(self, transformation_cache, state_manager): """Returns an IdWeightPair. @@ -3580,27 +4137,11 @@ class SequenceCategoricalColumn(FeatureColumn, """ sparse_tensors = self.categorical_column.get_sparse_tensors( transformation_cache, state_manager) - id_tensor = sparse_tensors.id_tensor - weight_tensor = sparse_tensors.weight_tensor - # Expands final dimension, so that embeddings are not combined during - # embedding lookup. - check_id_rank = check_ops.assert_equal( - array_ops.rank(id_tensor), 2, - data=[ - 'Column {} expected ID tensor of rank 2. '.format(self.name), - 'id_tensor shape: ', array_ops.shape(id_tensor)]) - with ops.control_dependencies([check_id_rank]): - id_tensor = sparse_ops.sparse_reshape( - id_tensor, - shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0)) - if weight_tensor is not None: - check_weight_rank = check_ops.assert_equal( - array_ops.rank(weight_tensor), 2, - data=[ - 'Column {} expected weight tensor of rank 2.'.format(self.name), - 'weight_tensor shape:', array_ops.shape(weight_tensor)]) - with ops.control_dependencies([check_weight_rank]): - weight_tensor = sparse_ops.sparse_reshape( - weight_tensor, - shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0)) - return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor) + return self._get_sparse_tensors_helper(sparse_tensors) + + @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, + _FEATURE_COLUMN_DEPRECATION) + def _get_sparse_tensors(self, inputs, weight_collections=None, + trainable=None): + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access + return self._get_sparse_tensors_helper(sparse_tensors) |