path: root/tensorflow/python/feature_column
author    Yifei Feng <yifeif@google.com>  2018-07-06 13:31:44 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-07-06 13:34:42 -0700
commit    602f8fad24eded4a7fcf5289840e2c646afc1bd0 (patch)
tree      2d0901e379f322f585a98d1bc76a0394be95474a /tensorflow/python/feature_column
parent    cbeaf2947a9627fbf3aa2dceee465a50f81c0534 (diff)
Merge changes from github.
PiperOrigin-RevId: 203518000
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--  tensorflow/python/feature_column/feature_column.py       | 169
-rw-r--r--  tensorflow/python/feature_column/feature_column_test.py  |   6
2 files changed, 12 insertions(+), 163 deletions(-)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 40219e4b34..d091d2fe0a 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -2158,7 +2158,7 @@ def _create_categorical_column_weighted_sum(column,
initializer=init_ops.zeros_initializer(),
trainable=trainable,
collections=weight_collections)
- return _safe_embedding_lookup_sparse(
+ return embedding_ops.safe_embedding_lookup_sparse(
weight,
id_tensor,
sparse_weights=weight_tensor,
@@ -2594,7 +2594,7 @@ class _EmbeddingColumn(
})
# Return embedding lookup result.
- return _safe_embedding_lookup_sparse(
+ return embedding_ops.safe_embedding_lookup_sparse(
embedding_weights=embedding_weights,
sparse_ids=sparse_ids,
sparse_weights=sparse_weights,
@@ -2736,7 +2736,7 @@ class _SharedEmbeddingColumn(
})
# Return embedding lookup result.
- return _safe_embedding_lookup_sparse(
+ return embedding_ops.safe_embedding_lookup_sparse(
embedding_weights=embedding_weights,
sparse_ids=sparse_ids,
sparse_weights=sparse_weights,
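The three hunks above replace call sites of the file-private `_safe_embedding_lookup_sparse` with the public `embedding_ops.safe_embedding_lookup_sparse`, resolving the TODO on the removed copy below. As a minimal sketch of a direct call (assuming TF 1.x graph mode; the weights and ids here are illustrative, not taken from this commit):

import tensorflow as tf
from tensorflow.python.ops import embedding_ops

# 4-row vocabulary, 2-dim embeddings (illustrative values).
embedding_weights = tf.constant(
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]])

# Row 0 has ids {0, 2}; row 1 holds only an invalid id (-1); row 2 is empty.
sparse_ids = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([0, 2, -1], dtype=tf.int64),
    dense_shape=[3, 2])

# Invalid ids are pruned, and rows left empty come back as zero vectors.
result = embedding_ops.safe_embedding_lookup_sparse(
    embedding_weights, sparse_ids, combiner='mean')

with tf.Session() as sess:
    # Row 0: mean of embedding rows 0 and 2 -> [0.3, 0.4]; rows 1, 2: zeros.
    print(sess.run(result))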
@@ -3228,161 +3228,6 @@ def _collect_leaf_level_keys(cross):
return leaf_level_keys
-# TODO(zakaria): Move this to embedding_ops and make it public.
-def _safe_embedding_lookup_sparse(embedding_weights,
- sparse_ids,
- sparse_weights=None,
- combiner='mean',
- default_id=None,
- name=None,
- partition_strategy='div',
- max_norm=None):
- """Lookup embedding results, accounting for invalid IDs and empty features.
-
- The partitioned embedding in `embedding_weights` must all be the same shape
- except for the first dimension. The first dimension is allowed to vary as the
- vocabulary size is not necessarily a multiple of `P`. `embedding_weights`
- may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a
- partitioner.
-
- Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
- with non-positive weight. For an entry with no features, the embedding vector
- for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
-
- The ids and weights may be multi-dimensional. Embeddings are always aggregated
- along the last dimension.
-
- Args:
- embedding_weights: A list of `P` float `Tensor`s or values representing
- partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable`
- created by partitioning along dimension 0. The total unpartitioned
- shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
- vocab size and `e_1, ..., e_m` are the embedding dimensions.
- sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
- ids. `d_0` is typically batch size.
- sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
- float weights corresponding to `sparse_ids`, or `None` if all weights
- are assumed to be 1.0.
- combiner: A string specifying how to combine embedding results for each
- entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
- the default.
- default_id: The id to use for an entry with no features.
- name: A name for this operation (optional).
- partition_strategy: A string specifying the partitioning strategy.
- Currently `"div"` and `"mod"` are supported. Default is `"div"`.
- max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
- combining.
-
-
- Returns:
- Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.
-
- Raises:
- ValueError: if `embedding_weights` is empty.
- """
- if embedding_weights is None:
- raise ValueError('Missing embedding_weights %s.' % embedding_weights)
- if isinstance(embedding_weights, variables.PartitionedVariable):
- embedding_weights = list(embedding_weights) # get underlying Variables.
- if not isinstance(embedding_weights, list):
- embedding_weights = [embedding_weights]
- if len(embedding_weights) < 1:
- raise ValueError('Missing embedding_weights %s.' % embedding_weights)
-
- dtype = sparse_weights.dtype if sparse_weights is not None else None
- embedding_weights = [
- ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
- ]
-
- with ops.name_scope(name, 'embedding_lookup',
- embedding_weights + [sparse_ids,
- sparse_weights]) as scope:
- # Reshape higher-rank sparse ids and weights to linear segment ids.
- original_shape = sparse_ids.dense_shape
- original_rank_dim = sparse_ids.dense_shape.get_shape()[0]
- original_rank = (
- array_ops.size(original_shape)
- if original_rank_dim.value is None
- else original_rank_dim.value)
- sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
- math_ops.reduce_prod(
- array_ops.slice(original_shape, [0], [original_rank - 1])),
- array_ops.gather(original_shape, original_rank - 1)])
- if sparse_weights is not None:
- sparse_weights = sparse_tensor_lib.SparseTensor(
- sparse_ids.indices,
- sparse_weights.values, sparse_ids.dense_shape)
-
- # Prune invalid ids and weights.
- sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
- if combiner != 'sum':
- sparse_ids, sparse_weights = _prune_invalid_weights(
- sparse_ids, sparse_weights)
-
- # Fill in dummy values for empty features, if necessary.
- sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
- default_id or
- 0)
- if sparse_weights is not None:
- sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)
-
- result = embedding_ops.embedding_lookup_sparse(
- embedding_weights,
- sparse_ids,
- sparse_weights,
- combiner=combiner,
- partition_strategy=partition_strategy,
- name=None if default_id is None else scope,
- max_norm=max_norm)
-
- if default_id is None:
- # Broadcast is_row_empty to the same shape as embedding_lookup_result,
- # for use in Select.
- is_row_empty = array_ops.tile(
- array_ops.reshape(is_row_empty, [-1, 1]),
- array_ops.stack([1, array_ops.shape(result)[1]]))
-
- result = array_ops.where(is_row_empty,
- array_ops.zeros_like(result),
- result,
- name=scope)
-
- # Reshape back from linear ids back into higher-dimensional dense result.
- final_result = array_ops.reshape(
- result,
- array_ops.concat([
- array_ops.slice(
- math_ops.cast(original_shape, dtypes.int32), [0],
- [original_rank - 1]),
- array_ops.slice(array_ops.shape(result), [1], [-1])
- ], 0))
- final_result.set_shape(tensor_shape.unknown_shape(
- (original_rank_dim - 1).value).concatenate(result.get_shape()[1:]))
- return final_result
-
-
-def _prune_invalid_ids(sparse_ids, sparse_weights):
- """Prune invalid IDs (< 0) from the input ids and weights."""
- is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
- if sparse_weights is not None:
- is_id_valid = math_ops.logical_and(
- is_id_valid,
- array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
- sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
- if sparse_weights is not None:
- sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
- return sparse_ids, sparse_weights
-
-
-def _prune_invalid_weights(sparse_ids, sparse_weights):
- """Prune invalid weights (< 0) from the input ids and weights."""
- if sparse_weights is not None:
- is_weights_valid = math_ops.greater(sparse_weights.values, 0)
- sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
- sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
- return sparse_ids, sparse_weights
-
-
class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn,
collections.namedtuple('_IndicatorColumn',
['categorical_column'])):
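The 155 lines removed above are not lost: the same implementation, including the `_prune_invalid_ids` and `_prune_invalid_weights` helpers, now lives behind `embedding_ops.safe_embedding_lookup_sparse`. A standalone sketch of the two pruning steps those helpers performed, assuming TF 1.x; the tensors are illustrative:

import tensorflow as tf

sparse_ids = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([3, -1, 5], dtype=tf.int64),
    dense_shape=[2, 2])
sparse_weights = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([1.0, 2.0, 0.0]),
    dense_shape=[2, 2])

# Step 1: drop entries whose id is < 0 (here the -1 at [0, 1]),
# retaining ids and weights in lockstep.
is_id_valid = tf.greater_equal(sparse_ids.values, 0)
sparse_ids = tf.sparse_retain(sparse_ids, is_id_valid)
sparse_weights = tf.sparse_retain(sparse_weights, is_id_valid)

# Step 2, applied only when combiner != 'sum': drop entries with
# non-positive weight (here the 0.0 at [1, 0]).
is_weight_valid = tf.greater(sparse_weights.values, 0)
sparse_ids = tf.sparse_retain(sparse_ids, is_weight_valid)
sparse_weights = tf.sparse_retain(sparse_weights, is_weight_valid)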
@@ -3419,10 +3264,14 @@ class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn,
sp_ids=id_tensor,
sp_values=weight_tensor,
vocab_size=int(self._variable_shape[-1]))
- # Remove (?, -1) index
+ # Remove (?, -1) index.
weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0],
weighted_column.dense_shape)
- return sparse_ops.sparse_tensor_to_dense(weighted_column)
+ # Use scatter_nd to merge duplicated indices, if any exist,
+ # instead of sparse_tensor_to_dense.
+ return array_ops.scatter_nd(weighted_column.indices,
+ weighted_column.values,
+ weighted_column.dense_shape)
dense_id_tensor = sparse_ops.sparse_tensor_to_dense(
id_tensor, default_value=-1)
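The `scatter_nd` switch above matters when a weighted categorical column contains a duplicated id: `sparse_tensor_to_dense` does not accumulate repeated indices (index validation rejects them), while `scatter_nd` sums them. A minimal illustration, assuming TF 1.x, with values chosen to mirror the test change below:

import tensorflow as tf

# The id 'c' appears twice, so the densified column carries a duplicated
# index [0, 2] with weights 2. and 1.
indices = tf.constant([[0, 2], [0, 2]], dtype=tf.int64)
values = tf.constant([2., 1.])
dense_shape = tf.constant([1, 3], dtype=tf.int64)

# scatter_nd accumulates duplicates: the result is [[0., 0., 3.]].
merged = tf.scatter_nd(indices, values, dense_shape)

with tf.Session() as sess:
    print(sess.run(merged))  # [[0. 0. 3.]]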
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 511205451c..5bb47bfa47 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -4580,12 +4580,12 @@ class IndicatorColumnTest(test.TestCase):
weights = fc.weighted_categorical_column(ids, 'weights')
indicator = fc.indicator_column(weights)
features = {
- 'ids': constant_op.constant([['c', 'b', 'a']]),
- 'weights': constant_op.constant([[2., 4., 6.]])
+ 'ids': constant_op.constant([['c', 'b', 'a', 'c']]),
+ 'weights': constant_op.constant([[2., 4., 6., 1.]])
}
indicator_tensor = _transform_features(features, [indicator])[indicator]
with _initialized_session():
- self.assertAllEqual([[6., 4., 2.]], indicator_tensor.eval())
+ self.assertAllEqual([[6., 4., 3.]], indicator_tensor.eval())
def test_transform_with_missing_value_in_weighted_column(self):
# Github issue 12583
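For the record, the updated expectation follows directly from that merge, assuming the test's vocabulary order is ['a', 'b', 'c']: 'a' carries weight 6., 'b' carries 4., and the duplicated 'c' sums 2. + 1. = 3. A plain-Python check of the arithmetic:

weights_by_id = {'a': [6.], 'b': [4.], 'c': [2., 1.]}
expected_row = [sum(weights_by_id[k]) for k in ('a', 'b', 'c')]
print(expected_row)  # [6.0, 4.0, 3.0]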