diff options
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/embedding_ops.py')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/embedding_ops.py | 357 |
1 files changed, 356 insertions, 1 deletions
# Reconstructed from a unified diff of
# tensorflow/contrib/layers/python/layers/embedding_ops.py
# (index f0ed31d1d1..b1a7f7ee59 100644).
#
# Hunk 1 (@@ -17,24 +17,31 @@) is the import block and `__all__` below.
# Hunk 2 (@@ -548,3 +555,351 @@) appends the functions below right after the
# end of `_sampled_scattered_embedding_lookup_sparse`; everything between
# `__all__` and that point is outside the diff and is not reproduced here.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging

__all__ = [
    "safe_embedding_lookup_sparse", "scattered_embedding_lookup",
    "scattered_embedding_lookup_sparse", "embedding_lookup_unique",
    "embedding_lookup_sparse_with_distributed_aggregation"
]


# ... (unchanged portion of the file elided by the diff) ...


def embedding_lookup_sparse_with_distributed_aggregation(
    params,
    sp_ids,
    sp_weights,
    partition_strategy="mod",
    name=None,
    combiner=None,
    max_norm=None):
  """Computes embeddings for the given ids and weights.

  Embeddings belonging to same param are aggregated on that device first. This
  op is intended to decrease data transmission and improve parallelism. See
  `tf.nn.embedding_lookup_sparse` for the functionality and example of this op.

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors.  Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
      where N is typically batch size and M is arbitrary.
    sp_weights: either a SparseTensor of float / double weights, or None to
      indicate all weights should be taken to be 1. If specified, sp_weights
      must have exactly the same shape and indices as sp_ids.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported. If None, a warning is logged and "mean" is used.
      "sum" computes the weighted sum of the embedding results for each row.
      "mean" is the weighted sum divided by the total weight.
      "sqrtn" is the weighted sum divided by the square root of the sum of the
      squares of the weights.
    max_norm: If not None, each embedding is clipped (via `clip_by_norm`) so
      that its l2 norm does not exceed max_norm, before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by sp_ids, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

  Raises:
    TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
      None nor SparseTensor.
    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    # Static-shape compatibility checks only; see the TODO below.
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse", params + [sp_ids]):
    # The row index of every sparse entry is the segment it is summed into.
    segment_ids = sp_ids.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    ids = sp_ids.values
    if ignore_weights:
      # Without weights every distinct id only needs to be looked up once;
      # `idx` maps each original entry to its row in the deduplicated lookup.
      ids, idx = array_ops.unique(ids)
      weights = None
    else:
      idx = None
      weights = sp_weights.values

    embeddings = _embedding_lookup_with_distributed_aggregation(
        params,
        ids,
        partition_strategy=partition_strategy,
        max_norm=max_norm,
        weights=weights,
        idx=idx,
        segment_ids=segment_ids)

    # The helper already returns the per-row weighted sums, which is the final
    # answer for "sum"; only "mean" and "sqrtn" need a normalizer.
    if combiner == "sum":
      return embeddings

    # Set weights to all one if ignore weights.
    if ignore_weights:
      weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1)
    weights = _reshape_weights_for_broadcast(weights, embeddings)

    if combiner == "mean":
      weight_sum = math_ops.segment_sum(weights, segment_ids)
      embeddings = math_ops.div(embeddings, weight_sum)
    else:  # combiner == "sqrtn", guaranteed by the validation above.
      weights_squared = math_ops.pow(weights, 2)
      weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
      weight_sum_sqrt = math_ops.sqrt(weight_sum)
      embeddings = math_ops.div(embeddings, weight_sum_sqrt)
    return embeddings


def _reshape_weights_for_broadcast(weights, values):
  """Casts `weights` to `values.dtype` and appends size-1 trailing dimensions.

  The result has shape `shape(weights) + [1] * (rank(values) - 1)`, so it can
  be broadcast against `values` along all but the leading dimension.

  Args:
    weights: A `Tensor` of weights.
    values: The `Tensor` whose dtype and rank the result must match.

  Returns:
    The cast and reshaped weights `Tensor`.
  """
  if weights.dtype != values.dtype:
    weights = math_ops.cast(weights, values.dtype)
  # Build the target shape dynamically so it also works for unknown rank.
  ones = array_ops.fill(
      array_ops.expand_dims(array_ops.rank(values) - 1, 0), 1)
  bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], 0)
  orig_weights_shape = weights.get_shape()
  weights = array_ops.reshape(weights, bcast_weights_shape)
  # Reshaping with a dynamic shape drops static shape information; restore it
  # whenever the rank of `values` is statically known.
  static_rank = values.get_shape().ndims
  if static_rank is not None:
    weights.set_shape(orig_weights_shape.concatenate([1] * (static_rank - 1)))
  return weights


def _do_gather(params, ids, validate_indices=True, name=None):
  """Deals with doing gather differently for resource variables.

  Args:
    params: A `ResourceVariable` or anything accepted by `array_ops.gather`.
    ids: The indices to gather.
    validate_indices: Forwarded to `array_ops.gather`; not used for resource
      variables.
    name: Optional name for the op.

  Returns:
    The gathered rows of `params`.
  """
  if isinstance(params, resource_variable_ops.ResourceVariable):
    # sparse_read avoids an accidental dense read of the whole variable.
    return params.sparse_read(ids, name=name)
  return array_ops.gather(
      params, ids, name=name, validate_indices=validate_indices)


def _embedding_lookup_with_distributed_aggregation(params,
                                                   ids,
                                                   partition_strategy="mod",
                                                   name=None,
                                                   validate_indices=True,
                                                   max_norm=None,
                                                   weights=None,
                                                   idx=None,
                                                   segment_ids=None):
  """Lookup helper for embedding_lookup_sparse_with_distributed_aggregation.

  Looks up `ids` in `params`, optionally clips and weights the embeddings, and
  sums them per segment. With several partitions, a partial per-segment sum is
  built for each partition under `ops.colocate_with(params[p])` and the
  partial sums are then combined with a final `unsorted_segment_sum`.

  Args:
    params: A single tensor, a list of tensors, or a `PartitionedVariable`.
    ids: The ids to look up. When `weights` is None these are the deduplicated
      ids and `idx` maps every original entry to its position in `ids`.
    partition_strategy: `"mod"` or `"div"`; only used with several partitions.
    name: Optional name for the op.
    validate_indices: Forwarded to the gather ops.
    max_norm: If not None, embeddings are clipped to this l2 norm.
    weights: Per-entry weights, or None to treat all weights as 1.
    idx: Index of every original entry into `ids`; used only when `weights` is
      None.
    segment_ids: int32 segment (row) id of every original entry.

  Returns:
    A `Tensor` holding the (weighted) sum of embeddings for every segment.

  Raises:
    ValueError: If `params` is empty or `partition_strategy` is unrecognized.
  """
  if params is None or params == []:  # pylint: disable=g-explicit-bool-comparison
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  def maybe_normalize(x):
    """Clips `x` over all but its first axis when `max_norm` is set."""
    if max_norm is None:
      return x
    static_rank = x.get_shape().ndims
    if static_rank is not None:
      axes = list(range(1, static_rank))
    else:
      # The rank is only known at run time, so the axes must be a tensor;
      # Python's range() cannot consume a Tensor.
      axes = math_ops.range(1, array_ops.rank(x))
    return clip_ops.clip_by_norm(x, max_norm, axes=axes)

  with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation",
                      params + [ids]) as name:
    num_partitions = len(params)
    ignore_weights = weights is None
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")

    if num_partitions == 1:
      with ops.colocate_with(params[0]):
        ret = maybe_normalize(
            _do_gather(params[0], ids, validate_indices=validate_indices))
        if not ignore_weights:
          ret *= _reshape_weights_for_broadcast(weights, ret)
          return math_ops.segment_sum(ret, segment_ids, name=name)
        # `ids` were deduplicated by the caller; `idx` selects the row of `ret`
        # belonging to every original entry.
        return math_ops.sparse_segment_sum(ret, idx, segment_ids, name=name)

    ids = ops.convert_to_tensor(ids, name="ids")
    flat_ids = array_ops.reshape(ids, [-1])
    original_indices = math_ops.range(array_ops.size(flat_ids))

    # Create p_assignments and set new_ids depending on the strategy.
    if partition_strategy == "mod":
      p_assignments = flat_ids % num_partitions
      new_ids = flat_ids // num_partitions
    elif partition_strategy == "div":
      # Compute num_total_ids as the sum of dim-0 of params, then assign to
      # partitions based on a constant number of ids per partition. Optimize
      # if we already know the full shape statically.
      dim_0_size = params[0].get_shape()[0]
      for p in xrange(1, num_partitions):
        dim_0_size += params[p].get_shape()[0]
      if dim_0_size.value:
        num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
      else:
        dim_0_sizes = []
        for p in xrange(num_partitions):
          if params[p].get_shape()[0].value is not None:
            dim_0_sizes.append(params[p].get_shape()[0].value)
          else:
            with ops.colocate_with(params[p]):
              dim_0_sizes.append(array_ops.shape(params[p])[0])
        num_total_ids = math_ops.reduce_sum(
            math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
      ids_per_partition = num_total_ids // num_partitions
      extras = num_total_ids % num_partitions

      # The first `extras` partitions hold one more id than the others.
      p_assignments = math_ops.maximum(
          flat_ids // (ids_per_partition + 1),
          (flat_ids - extras) // ids_per_partition)

      # Emulate a conditional using a boolean indicator tensor
      is_in_first_extras_partitions = math_ops.cast(p_assignments < extras,
                                                    flat_ids.dtype)
      new_ids = (is_in_first_extras_partitions *
                 (flat_ids % (ids_per_partition + 1)) +
                 (1 - is_in_first_extras_partitions) *
                 ((flat_ids - extras) % ids_per_partition))
    else:
      raise ValueError("Unrecognized partition strategy: " +
                       partition_strategy)

    # Cast partition assignments to int32 for use in dynamic_partition.
    # There really should not be more than 2^32 partitions.
    p_assignments = math_ops.cast(p_assignments, dtypes.int32)
    # Partition list of ids based on assignments into num_partitions lists.
    gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments,
                                                 num_partitions)
    # Similarly, partition the original indices.
    pindices = data_flow_ops.dynamic_partition(original_indices,
                                               p_assignments, num_partitions)
    # Do one lookup per partition: embeddings for gather_ids[p] in params[p].
    partitioned_result = []
    for p in xrange(num_partitions):
      with ops.colocate_with(params[p]):
        partitioned_result.append(
            _do_gather(
                params[p], gather_ids[p], validate_indices=validate_indices))

    if not ignore_weights:
      # Partition weights according to pindices.
      partitioned_weight = [
          array_ops.gather(weights, pindices[p])
          for p in xrange(num_partitions)
      ]

    # Reshape each partition result.
    element_shape = params[0].get_shape()[1:]
    for param in params[1:]:
      element_shape = element_shape.merge_with(param.get_shape()[1:])
    if element_shape.is_fully_defined():
      for p in xrange(num_partitions):
        with ops.colocate_with(params[p]):
          partitioned_result[p] = array_ops.reshape(
              partitioned_result[p],
              array_ops.concat([array_ops.shape(pindices[p]), element_shape],
                               0))
    else:
      # Fall back to the dynamic shape of the first partition's elements.
      with ops.colocate_with(params[0]):
        params_shape = array_ops.shape(params[0])
      for p in xrange(num_partitions):
        with ops.colocate_with(params[p]):
          partitioned_result[p] = array_ops.reshape(
              partitioned_result[p],
              array_ops.concat([
                  array_ops.shape(pindices[p]),
                  array_ops.slice(params_shape, [1], [-1])
              ], 0))

    # Normalize each partition result.
    for p in xrange(num_partitions):
      with ops.colocate_with(params[p]):
        partitioned_result[p] = maybe_normalize(partitioned_result[p])

    if not ignore_weights:
      # Multiply each partition result with partition weights.
      for p in xrange(num_partitions):
        with ops.colocate_with(params[p]):
          partitioned_weight[p] = _reshape_weights_for_broadcast(
              partitioned_weight[p], partitioned_result[p])
          partitioned_result[p] *= partitioned_weight[p]

    partitioned_segment_ids = []
    for p in xrange(num_partitions):
      if not ignore_weights:
        # Partition segment_ids according to pindices.
        p_segment_ids = array_ops.gather(segment_ids, pindices[p])
        # Number the p_segment_ids to meet segment_sum's requirements. Note
        # that unique_p_segment_ids contains unique segment ids of this
        # partition and these ids' order is unchanged.
        unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
            p_segment_ids)
        partitioned_segment_ids.append(unique_p_segment_ids)
        # segment_sum this partition's result.
        with ops.colocate_with(params[p]):
          partitioned_result[p] = math_ops.segment_sum(
              partitioned_result[p], unique_p_segment_idx)
      else:
        # When ignoring weights, `ids` were deduplicated, so first find the
        # positions in idx (and segment_ids) whose id lives in this partition:
        # take the positions NOT in the partition, then complement them.
        _, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
        all_idx = math_ops.range(array_ops.shape(idx)[0])
        _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
        # Gather segment_ids and idx according to those positions.
        p_segment_ids = array_ops.gather(segment_ids, include_idx)
        p_idx = array_ops.gather(idx, include_idx)
        # Number the p_segment_ids, same as in the weighted case above.
        unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
            p_segment_ids)
        # Renumber p_idx so it indexes rows of this partition's result.
        _, unique_p_idx_idx = array_ops.unique(p_idx)
        partitioned_segment_ids.append(unique_p_segment_ids)
        with ops.colocate_with(params[p]):
          partitioned_result[p] = math_ops.sparse_segment_sum(
              partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)

    # Concat each partition's segment_ids and result for final segment_sum.
    concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
    concat_partitioned_result = array_ops.concat(partitioned_result, 0)
    return math_ops.unsorted_segment_sum(
        concat_partitioned_result,
        concat_segment_ids,
        math_ops.reduce_max(concat_segment_ids) + 1,
        name=name)