diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-21 16:13:28 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-21 16:16:04 -0700 |
commit | d67a99406f986021df9a0f8a99ff6eaf801dcb25 (patch) | |
tree | 3858c006d6638aa273d861b6f5c4ce4fb02d8cee /tensorflow/contrib/layers | |
parent | a4afd46dc55c060b7f5132f0d088702169f864e4 (diff) |
Expose partition_strategy option in embedding_lookup_unique
PiperOrigin-RevId: 197477853
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/embedding_ops.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py index 49c3faf3b7..60e1d85ea9 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -458,7 +458,7 @@ def scattered_embedding_lookup_sparse(params, return embeddings -def embedding_lookup_unique(params, ids, name=None): +def embedding_lookup_unique(params, ids, partition_strategy="mod", name=None): """Version of embedding_lookup that avoids duplicate lookups. This can save communication in the case of repeated ids. @@ -470,6 +470,9 @@ def embedding_lookup_unique(params, ids, name=None): `PartitionedVariable`. Shape `[index, d1, d2, ...]`. ids: A one-dimensional `Tensor` with type `int32` or `int64` containing the ids to be looked up in `params`. Shape `[ids1, ids2, ...]`. + partition_strategy: A string specifying the partitioning strategy, relevant + if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default + is `"mod"`. name: A name for this operation (optional). Returns: @@ -485,7 +488,8 @@ def embedding_lookup_unique(params, ids, name=None): ids_flat = array_ops.reshape( ids, math_ops.reduce_prod(shape, keepdims=True)) unique_ids, idx = array_ops.unique(ids_flat) - unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids) + unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids, + partition_strategy) embeds_flat = array_ops.gather(unique_embeddings, idx) embed_shape = array_ops.concat( [shape, array_ops.shape(unique_embeddings)[1:]], 0) |