path: root/tensorflow/python/ops/embedding_ops.py
author     A. Unique TensorFlower <gardener@tensorflow.org> 2017-01-30 13:43:35 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org> 2017-01-30 14:09:08 -0800
commit     79a93ac627b9af8ae84a874ce248fe42aac8de36 (patch)
tree       f0c82451dcf766c5e7ec04a30199c06f59c42752 /tensorflow/python/ops/embedding_ops.py
parent     696dad03880e9b3367b6c4b4c3903d6aa723d7e5 (diff)
Support partitioned embedding lookup for resource variables.
Change: 146034474
Diffstat (limited to 'tensorflow/python/ops/embedding_ops.py')
-rw-r--r--  tensorflow/python/ops/embedding_ops.py  29
1 file changed, 18 insertions(+), 11 deletions(-)
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 80507024d2..1e02cc6c48 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -33,6 +33,14 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+def _do_gather(params, ids, validate_indices=True, name=None):
+ """Deals with doing gather differently for resource variables."""
+ if isinstance(params, resource_variable_ops.ResourceVariable):
+ return params.sparse_read(ids, name=name)
+ return array_ops.gather(
+ params, ids, name=name, validate_indices=validate_indices)
+
+
def embedding_lookup(params, ids, partition_strategy="mod", name=None,
validate_indices=True, max_norm=None):
"""Looks up `ids` in a list of embedding tensors.
@@ -100,16 +108,15 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
return x
with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
np = len(params) # Number of partitions
- params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
+ # Preserve the resource variable status to avoid accidental dense reads.
+ if not any(isinstance(p, resource_variable_ops.ResourceVariable)
+ for p in params):
+ params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
if np == 1:
with ops.colocate_with(params[0]):
- # TODO(apassos): implement the sharded version as well.
- if isinstance(params[0], resource_variable_ops.ResourceVariable):
- ret = params[0].sparse_read(ids, name=name)
- else:
- ret = array_ops.gather(params[0], ids, name=name,
- validate_indices=validate_indices)
- return maybe_normalize(ret)
+ return maybe_normalize(
+ _do_gather(
+ params[0], ids, validate_indices=validate_indices, name=name))
else:
ids = ops.convert_to_tensor(ids, name="ids")
flat_ids = array_ops.reshape(ids, [-1])
@@ -169,9 +176,9 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
partitioned_result = []
for p in xrange(np):
with ops.colocate_with(params[p]):
- partitioned_result.append(array_ops.gather(
- params[p], gather_ids[p],
- validate_indices=validate_indices))
+ partitioned_result.append(
+ _do_gather(params[p], gather_ids[p],
+ validate_indices=validate_indices))
# Stitch these back together
ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
name=name)
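
For context, here is a minimal usage sketch of what this change enables: passing a list of ResourceVariable shards to embedding_lookup so that the partitioned (np > 1) path reads each shard through the new _do_gather helper, which calls sparse_read instead of converting the variable to a dense tensor. This example is not part of the commit; the variable names, shapes, and the use_resource flag are illustrative assumptions about the TF 1.x API of this era.

    import tensorflow as tf

    # Two shards of a 10-row, 4-dim embedding table, created as resource
    # variables (use_resource=True requests a ResourceVariable in TF 1.x;
    # this flag is an assumption about the surrounding TF version).
    shard0 = tf.get_variable("emb_part_0", shape=[5, 4], use_resource=True)
    shard1 = tf.get_variable("emb_part_1", shape=[5, 4], use_resource=True)

    ids = tf.constant([0, 3, 7, 9])

    # With partition_strategy="mod", id i is looked up in shard i % 2.
    # After this commit, each per-shard lookup goes through _do_gather and
    # therefore uses ResourceVariable.sparse_read on resource variables.
    emb = tf.nn.embedding_lookup([shard0, shard1], ids,
                                 partition_strategy="mod")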