diff options
Diffstat (limited to 'tensorflow/python/ops/embedding_ops.py')
-rw-r--r-- | tensorflow/python/ops/embedding_ops.py | 31 |
1 files changed, 19 insertions, 12 deletions
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index bc64593d23..d9a377bab5 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -8,24 +8,31 @@ from tensorflow.python.ops import math_ops def embedding_lookup(params, ids, name=None): - """Return a tensor of embedding values by looking up "ids" in "params". + """Looks up `ids` in a list of embedding tensors. + + This function is used to perform parallel lookups on the list of + tensors in `params`. It is a generalization of + [`tf.gather()`](array_ops.md#gather), where `params` is interpreted + as a partition of a larger embedding tensor. + + If `len(params) > 1`, each element `id` of `ids` is partitioned between + the elements of `params` by computing `p = id % len(params)`, and is + then used to look up the slice `params[p][id // len(params), ...]`. + + The results of the lookup are then concatenated into a dense + tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`. Args: - params: List of tensors of the same shape. A single tensor is - treated as a singleton list. - ids: Tensor of integers containing the ids to be looked up in - 'params'. Let P be len(params). If P > 1, then the ids are - partitioned by id % P, and we do separate lookups in params[p] - for 0 <= p < P, and then stitch the results back together into - a single result tensor. - name: Optional name for the op. + params: A list of tensors with the same shape and type. + ids: A `Tensor` with type `int32` containing the ids to be looked + up in `params`. + name: A name for the operation (optional). Returns: - A tensor of shape ids.shape + params[0].shape[1:] containing the - values params[i % P][i] for each i in ids. + A `Tensor` with the same type as the tensors in `params`. Raises: - ValueError: if some parameters are invalid. + ValueError: If `params` is empty. """ if not isinstance(params, list): params = [params] |