From 153decedefc8da1fbd0717f4223b4b053e7aa517 Mon Sep 17 00:00:00 2001 From: Karmel Allison Date: Mon, 8 Oct 2018 10:36:38 -0700 Subject: Add support for SequenceExamples to sequence_feature_columns PiperOrigin-RevId: 216210141 --- tensorflow/python/feature_column/feature_column.py | 53 +++++++++++++--------- tensorflow/python/ops/parsing_ops.py | 13 +++--- 2 files changed, 38 insertions(+), 28 deletions(-) (limited to 'tensorflow/python') diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 5352796174..28a8286544 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -2660,6 +2660,7 @@ class _EmbeddingColumn( inputs=inputs, weight_collections=weight_collections, trainable=trainable) + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access sequence_length = _sequence_length_from_sparse_tensor( sparse_tensors.id_tensor) @@ -3383,6 +3384,16 @@ class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn, def _verify_static_batch_size_equality(tensors, columns): + """Validates that the first dim (batch size) of all tensors are equal or None. + + Args: + tensors: list of tensors to check. + columns: list of feature columns matching tensors. Will be used for error + messaging. + + Raises: + ValueError: if one of the tensors has a different batch size + """ # batch_size is a tf.Dimension object. 
expected_batch_size = None for i in range(0, len(tensors)): @@ -3403,9 +3414,18 @@ def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1): with ops.name_scope(None, 'sequence_length') as name_scope: row_ids = sp_tensor.indices[:, 0] column_ids = sp_tensor.indices[:, 1] + # Add one to convert column indices to element length column_ids += array_ops.ones_like(column_ids) - seq_length = math_ops.to_int64( - math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements) + # Get the number of elements we will have per example/row + seq_length = math_ops.segment_max(column_ids, segment_ids=row_ids) + + # The raw values are grouped according to num_elements; + # how many entities will we have after grouping? + # Example: orig tensor [[1, 2], [3]], col_ids = (0, 1, 1), + # row_ids = (0, 0, 1), seq_length = [2, 1]. If num_elements = 2, + # these will get grouped, and the final seq_length is [1, 1] + seq_length = math_ops.to_int64(math_ops.ceil(seq_length / num_elements)) + # If the last n rows do not have ids, seq_length will have shape # [batch_size - n]. Pad the remaining values with zeros. n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1] @@ -3439,25 +3459,14 @@ class _SequenceCategoricalColumn( sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access id_tensor = sparse_tensors.id_tensor weight_tensor = sparse_tensors.weight_tensor - # Expands final dimension, so that embeddings are not combined during - # embedding lookup. - check_id_rank = check_ops.assert_equal( - array_ops.rank(id_tensor), 2, - data=[ - 'Column {} expected ID tensor of rank 2. 
'.format(self.name), - 'id_tensor shape: ', array_ops.shape(id_tensor)]) - with ops.control_dependencies([check_id_rank]): - id_tensor = sparse_ops.sparse_reshape( - id_tensor, - shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0)) + + # Expands third dimension, if necessary so that embeddings are not + # combined during embedding lookup. If the tensor is already 3D, leave + # as-is. + shape = array_ops.shape(id_tensor) + target_shape = [shape[0], shape[1], -1] + id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape) if weight_tensor is not None: - check_weight_rank = check_ops.assert_equal( - array_ops.rank(weight_tensor), 2, - data=[ - 'Column {} expected weight tensor of rank 2.'.format(self.name), - 'weight_tensor shape:', array_ops.shape(weight_tensor)]) - with ops.control_dependencies([check_weight_rank]): - weight_tensor = sparse_ops.sparse_reshape( - weight_tensor, - shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0)) + weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape) + return _CategoricalColumn.IdWeightPair(id_tensor, weight_tensor) diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py index ff50fe0d09..a2da6412ed 100644 --- a/tensorflow/python/ops/parsing_ops.py +++ b/tensorflow/python/ops/parsing_ops.py @@ -217,21 +217,21 @@ def _features_to_raw_params(features, types): feature = features[key] if isinstance(feature, VarLenFeature): if VarLenFeature not in types: - raise ValueError("Unsupported VarLenFeature %s." % feature) + raise ValueError("Unsupported VarLenFeature %s." % (feature,)) if not feature.dtype: raise ValueError("Missing type for feature %s." % key) sparse_keys.append(key) sparse_types.append(feature.dtype) elif isinstance(feature, SparseFeature): if SparseFeature not in types: - raise ValueError("Unsupported SparseFeature %s." % feature) + raise ValueError("Unsupported SparseFeature %s." 
% (feature,)) if not feature.index_key: raise ValueError( - "Missing index_key for SparseFeature %s." % feature) + "Missing index_key for SparseFeature %s." % (feature,)) if not feature.value_key: raise ValueError( - "Missing value_key for SparseFeature %s." % feature) + "Missing value_key for SparseFeature %s." % (feature,)) if not feature.dtype: raise ValueError("Missing type for feature %s." % key) index_keys = feature.index_key @@ -260,7 +260,7 @@ def _features_to_raw_params(features, types): sparse_types.append(feature.dtype) elif isinstance(feature, FixedLenFeature): if FixedLenFeature not in types: - raise ValueError("Unsupported FixedLenFeature %s." % feature) + raise ValueError("Unsupported FixedLenFeature %s." % (feature,)) if not feature.dtype: raise ValueError("Missing type for feature %s." % key) if feature.shape is None: @@ -281,7 +281,8 @@ def _features_to_raw_params(features, types): dense_defaults[key] = feature.default_value elif isinstance(feature, FixedLenSequenceFeature): if FixedLenSequenceFeature not in types: - raise ValueError("Unsupported FixedLenSequenceFeature %s." % feature) + raise ValueError("Unsupported FixedLenSequenceFeature %s." % ( + feature,)) if not feature.dtype: raise ValueError("Missing type for feature %s." % key) if feature.shape is None: -- cgit v1.2.3