about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/layers
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-21 16:13:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-21 16:16:04 -0700
commitd67a99406f986021df9a0f8a99ff6eaf801dcb25 (patch)
tree3858c006d6638aa273d861b6f5c4ce4fb02d8cee /tensorflow/contrib/layers
parenta4afd46dc55c060b7f5132f0d088702169f864e4 (diff)
Expose partition_strategy option in embedding_lookup_unique
PiperOrigin-RevId: 197477853
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r-- tensorflow/contrib/layers/python/layers/embedding_ops.py | 8
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py
index 49c3faf3b7..60e1d85ea9 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py
@@ -458,7 +458,7 @@ def scattered_embedding_lookup_sparse(params,
return embeddings
-def embedding_lookup_unique(params, ids, name=None):
+def embedding_lookup_unique(params, ids, partition_strategy="mod", name=None):
"""Version of embedding_lookup that avoids duplicate lookups.
This can save communication in the case of repeated ids.
@@ -470,6 +470,9 @@ def embedding_lookup_unique(params, ids, name=None):
`PartitionedVariable`. Shape `[index, d1, d2, ...]`.
ids: A one-dimensional `Tensor` with type `int32` or `int64` containing
the ids to be looked up in `params`. Shape `[ids1, ids2, ...]`.
+ partition_strategy: A string specifying the partitioning strategy, relevant
+ if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
+ is `"mod"`.
name: A name for this operation (optional).
Returns:
@@ -485,7 +488,8 @@ def embedding_lookup_unique(params, ids, name=None):
ids_flat = array_ops.reshape(
ids, math_ops.reduce_prod(shape, keepdims=True))
unique_ids, idx = array_ops.unique(ids_flat)
- unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids)
+ unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids,
+ partition_strategy)
embeds_flat = array_ops.gather(unique_embeddings, idx)
embed_shape = array_ops.concat(
[shape, array_ops.shape(unique_embeddings)[1:]], 0)