diff options
Diffstat (limited to 'tensorflow/python/feature_column/feature_column.py')
-rw-r--r-- | tensorflow/python/feature_column/feature_column.py | 53 |
1 files changed, 31 insertions, 22 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 5352796174..28a8286544 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -2660,6 +2660,7 @@ class _EmbeddingColumn( inputs=inputs, weight_collections=weight_collections, trainable=trainable) + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access sequence_length = _sequence_length_from_sparse_tensor( sparse_tensors.id_tensor) @@ -3383,6 +3384,16 @@ class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn, def _verify_static_batch_size_equality(tensors, columns): + """Validates that the first dim (batch size) of all tensors are equal or None. + + Args: + tensors: list of tensors to check. + columns: list of feature columns matching tensors. Will be used for error + messaging. + + Raises: + ValueError: if one of the tensors has a variant batch size + """ # bath_size is a tf.Dimension object. expected_batch_size = None for i in range(0, len(tensors)): @@ -3403,9 +3414,18 @@ def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1): with ops.name_scope(None, 'sequence_length') as name_scope: row_ids = sp_tensor.indices[:, 0] column_ids = sp_tensor.indices[:, 1] + # Add one to convert column indices to element length column_ids += array_ops.ones_like(column_ids) - seq_length = math_ops.to_int64( - math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements) + # Get the number of elements we will have per example/row + seq_length = math_ops.segment_max(column_ids, segment_ids=row_ids) + + # The raw values are grouped according to num_elements; + # how many entities will we have after grouping? + # Example: orig tensor [[1, 2], [3]], col_ids = (0, 1, 1), + # row_ids = (0, 0, 1), seq_length = [2, 1]. If num_elements = 2, + # these will get grouped, and the final seq_length is [1, 1] + seq_length = math_ops.to_int64(math_ops.ceil(seq_length / num_elements)) + + # If the last n rows do not have ids, seq_length will have shape # [batch_size - n]. Pad the remaining values with zeros. n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1] @@ -3439,25 +3459,14 @@ class _SequenceCategoricalColumn( sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access id_tensor = sparse_tensors.id_tensor weight_tensor = sparse_tensors.weight_tensor - # Expands final dimension, so that embeddings are not combined during - # embedding lookup. - check_id_rank = check_ops.assert_equal( - array_ops.rank(id_tensor), 2, - data=[ - 'Column {} expected ID tensor of rank 2. '.format(self.name), - 'id_tensor shape: ', array_ops.shape(id_tensor)]) - with ops.control_dependencies([check_id_rank]): - id_tensor = sparse_ops.sparse_reshape( - id_tensor, - shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0)) + + # Expands third dimension, if necessary so that embeddings are not + # combined during embedding lookup. If the tensor is already 3D, leave + # as-is. + shape = array_ops.shape(id_tensor) + target_shape = [shape[0], shape[1], -1] + id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape) if weight_tensor is not None: - check_weight_rank = check_ops.assert_equal( - array_ops.rank(weight_tensor), 2, - data=[ - 'Column {} expected weight tensor of rank 2.'.format(self.name), - 'weight_tensor shape:', array_ops.shape(weight_tensor)]) - with ops.control_dependencies([check_weight_rank]): - weight_tensor = sparse_ops.sparse_reshape( - weight_tensor, - shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0)) + weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape) + return _CategoricalColumn.IdWeightPair(id_tensor, weight_tensor) |