from __future__ import division
from __future__ import print_function
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
__all__ = [
"safe_embedding_lookup_sparse", "scattered_embedding_lookup",
- "scattered_embedding_lookup_sparse", "embedding_lookup_unique"
+ "scattered_embedding_lookup_sparse", "embedding_lookup_unique",
+ "embedding_lookup_sparse_with_distributed_aggregation"
return math_ops.unsorted_segment_sum(embeddings, segment_ids,
def embedding_lookup_sparse_with_distributed_aggregation(
    params,
    sp_ids,
    sp_weights,
    partition_strategy="mod",
    name=None,
    combiner=None,
    max_norm=None):
  """Computes embeddings for the given ids and weights.

  Embeddings belonging to same param are aggregated on that device first. This
  op is intended to decrease data transmission and improve parallelism. See
  `tf.nn.embedding_lookup_sparse` for the functionality and example of this op.

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors.  Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
      where N is typically batch size and M is arbitrary.
    sp_weights: either a SparseTensor of float / double weights, or None to
      indicate all weights should be taken to be 1. If specified, sp_weights
      must have exactly the same shape and indices as sp_ids.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported. If None, a warning is logged and "mean" is used.
      "sum" computes the weighted sum of the embedding results for each row.
      "mean" is the weighted sum divided by the total weight.
      "sqrtn" is the weighted sum divided by the square root of the sum of the
      squares of the weights.
    max_norm: If not None, each embedding is clipped with `clip_by_norm` (i.e.
      scaled down only when its l2 norm exceeds `max_norm`) before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by sp_ids, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

  Raises:
    TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
      None nor SparseTensor.
    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    # Static-shape sanity checks only; equality of the actual indices is not
    # verified at run time (see the TODO below).
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    # Column 0 of the sparse indices is the row each id belongs to.
    segment_ids = sp_ids.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    ids = sp_ids.values
    if ignore_weights:
      # Without weights, duplicate ids can share one lookup; `idx` maps every
      # original entry back to its position in the de-duplicated `ids`.
      ids, idx = array_ops.unique(ids)
      weights = None
    else:
      idx = None
      weights = sp_weights.values

    # The helper returns the (weighted) per-row sum of the embeddings.
    embeddings = _embedding_lookup_with_distributed_aggregation(
        params,
        ids,
        partition_strategy=partition_strategy,
        max_norm=max_norm,
        weights=weights,
        idx=idx,
        segment_ids=segment_ids)

    if combiner == "sum":
      # Nothing to normalize by, so no weight tensors need to be built.
      return embeddings

    # Set weights to all one if ignore weights.
    if ignore_weights:
      weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1)
    if weights.dtype != embeddings.dtype:
      weights = math_ops.cast(weights, embeddings.dtype)

    # Reshape weights to [nnz, 1, ..., 1] so that the per-row weight totals
    # broadcast against the embedding dimensions.
    ones = array_ops.fill(
        array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
    bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], 0)
    orig_weights_shape = weights.get_shape()
    weights = array_ops.reshape(weights, bcast_weights_shape)
    # Restore the static shape information lost by the dynamic reshape.
    if embeddings.get_shape().ndims is not None:
      weights.set_shape(
          orig_weights_shape.concatenate(
              [1 for _ in range(embeddings.get_shape().ndims - 1)]))

    if combiner == "mean":
      denominator = math_ops.segment_sum(weights, segment_ids)
    else:
      # combiner == "sqrtn": the only other value accepted by the validation
      # at the top of the function.
      weights_squared = math_ops.pow(weights, 2)
      denominator = math_ops.sqrt(
          math_ops.segment_sum(weights_squared, segment_ids))
    return math_ops.div(embeddings, denominator)
def _do_gather(params, ids, validate_indices=True, name=None):
  """Gathers rows `ids` of `params`, special-casing resource variables.

  Args:
    params: A tensor-like object or a `ResourceVariable` to read rows from.
    ids: Indices of the rows to gather.
    validate_indices: Forwarded to `array_ops.gather`; not used on the
      resource-variable path.
    name: Optional name for the op.

  Returns:
    The gathered rows.
  """
  is_resource_variable = isinstance(params,
                                    resource_variable_ops.ResourceVariable)
  if not is_resource_variable:
    return array_ops.gather(
        params, ids, name=name, validate_indices=validate_indices)
  # Resource variables expose their own sparse read.
  return params.sparse_read(ids, name=name)
def _embedding_lookup_with_distributed_aggregation(params,
                                                   ids,
                                                   partition_strategy="mod",
                                                   name=None,
                                                   validate_indices=True,
                                                   max_norm=None,
                                                   weights=None,
                                                   idx=None,
                                                   segment_ids=None):
  """Lookup helper for embedding_lookup_sparse_with_distributed_aggregation.

  Gathers the embeddings for `ids`, optionally clips and weights them, and
  sums them per segment. With several shards, each shard's rows are summed
  under `colocate_with` that shard before the partial sums are combined.

  Args:
    params: A single tensor, a list of sharded tensors, or a
      `PartitionedVariable`.
    ids: Tensor of ids to look up. It is flattened on the multi-shard path.
    partition_strategy: `"mod"` or `"div"`. Only consulted when there is more
      than one shard.
    name: Optional name for the op.
    validate_indices: Forwarded to `_do_gather`.
    max_norm: If not None, each gathered embedding is clipped with
      `clip_by_norm` over all axes except the first.
    weights: Optional tensor with one weight per entry of `ids`. When given,
      embeddings are multiplied by it and combined with `segment_sum`.
    idx: Only used when `weights` is None: for every original entry, its
      position in `ids` (the caller obtains it from `array_ops.unique`).
      Rows are then combined with `sparse_segment_sum`.
    segment_ids: Segment (row) id of every original entry.

  Returns:
    A tensor holding, for each segment id, the (weighted) sum of its
    embeddings.

  Raises:
    ValueError: If `params` is None or an empty list, or if
      `partition_strategy` is not recognized on the multi-shard path.
  """
  if params is None or (isinstance(params, list) and not params):
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  def maybe_normalize(x):
    """Clips every embedding in `x` to `max_norm`, if one was requested."""
    if max_norm is None:
      return x
    static_rank = x.get_shape().ndims
    if static_rank is not None:
      axes = list(range(1, static_rank))
    else:
      # The rank is only known at run time, so the axes have to be a Tensor;
      # Python's range() cannot take a Tensor bound.
      axes = math_ops.range(1, array_ops.rank(x))
    return clip_ops.clip_by_norm(x, max_norm, axes=axes)

  def scale_by_weights(values, w):
    """Multiplies `values` by `w`, broadcast over the trailing dimensions."""
    if w.dtype != values.dtype:
      w = math_ops.cast(w, values.dtype)
    # Reshape `w` to [n, 1, ..., 1] to allow broadcast.
    ones = array_ops.fill(
        array_ops.expand_dims(array_ops.rank(values) - 1, 0), 1)
    bcast_weights_shape = array_ops.concat([array_ops.shape(w), ones], 0)
    orig_weights_shape = w.get_shape()
    w = array_ops.reshape(w, bcast_weights_shape)
    # Restore the static shape information lost by the dynamic reshape.
    if values.get_shape().ndims is not None:
      w.set_shape(
          orig_weights_shape.concatenate(
              [1 for _ in range(values.get_shape().ndims - 1)]))
    return values * w

  with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation",
                      params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    ignore_weights = weights is None

    if np == 1:
      with ops.colocate_with(params[0]):
        ret = maybe_normalize(
            _do_gather(params[0], ids, validate_indices=validate_indices))
        if ignore_weights:
          return math_ops.sparse_segment_sum(ret, idx, segment_ids, name=name)
        ret = scale_by_weights(ret, weights)
        return math_ops.segment_sum(ret, segment_ids, name=name)

    ids = ops.convert_to_tensor(ids, name="ids")
    flat_ids = array_ops.reshape(ids, [-1])
    original_indices = math_ops.range(array_ops.size(flat_ids))

    # Create p_assignments and set new_ids depending on the strategy.
    if partition_strategy == "mod":
      p_assignments = flat_ids % np
      new_ids = flat_ids // np
    elif partition_strategy == "div":
      # Compute num_total_ids as the sum of dim-0 of params, then assign to
      # partitions based on a constant number of ids per partition. Optimize
      # if we already know the full shape statically.
      dim_0_size = params[0].get_shape()[0]
      for p in xrange(1, np):
        dim_0_size += params[p].get_shape()[0]
      if dim_0_size.value:
        num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
      else:
        dim_0_sizes = []
        for p in xrange(np):
          if params[p].get_shape()[0].value is not None:
            dim_0_sizes.append(params[p].get_shape()[0].value)
          else:
            with ops.colocate_with(params[p]):
              dim_0_sizes.append(array_ops.shape(params[p])[0])
        num_total_ids = math_ops.reduce_sum(
            math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
      ids_per_partition = num_total_ids // np
      extras = num_total_ids % np

      p_assignments = math_ops.maximum(
          flat_ids // (ids_per_partition + 1),
          (flat_ids - extras) // ids_per_partition)

      # Emulate a conditional using a boolean indicator tensor
      is_in_first_extras_partitions = math_ops.cast(p_assignments < extras,
                                                    flat_ids.dtype)
      new_ids = (
          is_in_first_extras_partitions * (flat_ids % (ids_per_partition + 1))
          + (1 - is_in_first_extras_partitions) *
          ((flat_ids - extras) % ids_per_partition))
    else:
      raise ValueError("Unrecognized partition strategy: " +
                       partition_strategy)

    # Cast partition assignments to int32 for use in dynamic_partition.
    # There really should not be more than 2^32 partitions.
    p_assignments = math_ops.cast(p_assignments, dtypes.int32)
    # Partition list of ids based on assignments into np separate lists
    gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
    # Similarly, partition the original indices.
    pindices = data_flow_ops.dynamic_partition(original_indices,
                                               p_assignments, np)
    # Do np separate lookups, finding embeddings for gather_ids[p] in params[p]
    partitioned_result = []
    for p in xrange(np):
      with ops.colocate_with(params[p]):
        partitioned_result.append(
            _do_gather(
                params[p], gather_ids[p], validate_indices=validate_indices))

    if not ignore_weights:
      # Partition weights according to pindices.
      partitioned_weight = [
          array_ops.gather(weights, pindices[p]) for p in xrange(np)
      ]

    # Reshape each partition result.
    element_shape = params[0].get_shape()[1:]
    for param in params[1:]:
      element_shape = element_shape.merge_with(param.get_shape()[1:])
    if element_shape.is_fully_defined():
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result[p] = array_ops.reshape(
              partitioned_result[p],
              array_ops.concat([array_ops.shape(pindices[p]), element_shape],
                               0))
    else:
      # The element shape is not fully known statically; read it from the
      # first shard at run time.
      with ops.colocate_with(params[0]):
        params_shape = array_ops.shape(params[0])
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result[p] = array_ops.reshape(
              partitioned_result[p],
              array_ops.concat([
                  array_ops.shape(pindices[p]),
                  array_ops.slice(params_shape, [1], [-1])
              ], 0))

    # Normalize each partition result.
    for p in xrange(np):
      with ops.colocate_with(params[p]):
        partitioned_result[p] = maybe_normalize(partitioned_result[p])

    if not ignore_weights:
      # Multiply each partition result with partition weights.
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result[p] = scale_by_weights(partitioned_result[p],
                                                   partitioned_weight[p])

    partitioned_segment_ids = []
    for p in xrange(np):
      if not ignore_weights:
        # Partition segment_ids according to pindices.
        p_segment_ids = array_ops.gather(segment_ids, pindices[p])
        # Number the p_segment_ids to meet segment_sum's requirements. Note
        # that unique_p_segment_ids contains unique segment ids of this
        # partition and these ids' order is unchanged.
        unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
            p_segment_ids)
        partitioned_segment_ids.append(unique_p_segment_ids)
        # segment_sum this partition's result.
        with ops.colocate_with(params[p]):
          partitioned_result[p] = math_ops.segment_sum(
              partitioned_result[p], unique_p_segment_idx)
      else:
        # Without weights, `ids` was de-duplicated by the caller, so we need
        # the positions of the original entries (in idx and segment_ids) whose
        # id landed in this partition. They are found as the complement of
        # the entries that did not.
        _, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
        all_idx = math_ops.range(array_ops.shape(idx)[0])
        _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
        # Gather segment_ids and idx according to those positions.
        p_segment_ids = array_ops.gather(segment_ids, include_idx)
        p_idx = array_ops.gather(idx, include_idx)
        # Number the p_segment_ids, same as in the weighted case above, and
        # renumber p_idx so it indexes rows of this partition's result.
        unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
            p_segment_ids)
        _, unique_p_idx_idx = array_ops.unique(p_idx)
        partitioned_segment_ids.append(unique_p_segment_ids)
        with ops.colocate_with(params[p]):
          partitioned_result[p] = math_ops.sparse_segment_sum(
              partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)

    # Concat each partition's segment_ids and result for final segment_sum.
    concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
    concat_partitioned_result = array_ops.concat(partitioned_result, 0)
    return math_ops.unsorted_segment_sum(
        concat_partitioned_result,
        concat_segment_ids,
        math_ops.reduce_max(concat_segment_ids) + 1,
        name=name)