aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/embedding_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/embedding_ops.py')
-rw-r--r--tensorflow/python/ops/embedding_ops.py31
1 files changed, 19 insertions, 12 deletions
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index bc64593d23..d9a377bab5 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -8,24 +8,31 @@ from tensorflow.python.ops import math_ops
def embedding_lookup(params, ids, name=None):
- """Return a tensor of embedding values by looking up "ids" in "params".
+ """Looks up `ids` in a list of embedding tensors.
+
+ This function is used to perform parallel lookups on the list of
+ tensors in `params`. It is a generalization of
+ [`tf.gather()`](array_ops.md#gather), where `params` is interpreted
+ as a partition of a larger embedding tensor.
+
+ If `len(params) > 1`, each element `id` of `ids` is partitioned between
+ the elements of `params` by computing `p = id % len(params)`, and is
+ then used to look up the slice `params[p][id // len(params), ...]`.
+
+ The results of the lookup are then concatenated into a dense
+ tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
Args:
- params: List of tensors of the same shape. A single tensor is
- treated as a singleton list.
- ids: Tensor of integers containing the ids to be looked up in
- 'params'. Let P be len(params). If P > 1, then the ids are
- partitioned by id % P, and we do separate lookups in params[p]
- for 0 <= p < P, and then stitch the results back together into
- a single result tensor.
- name: Optional name for the op.
+ params: A list of tensors with the same shape and type.
+ ids: A `Tensor` with type `int32` containing the ids to be looked
+ up in `params`.
+ name: A name for the operation (optional).
Returns:
- A tensor of shape ids.shape + params[0].shape[1:] containing the
- values params[i % P][i] for each i in ids.
+ A `Tensor` with the same type as the tensors in `params`.
Raises:
- ValueError: if some parameters are invalid.
+ ValueError: If `params` is empty.
"""
if not isinstance(params, list):
params = [params]