author | 2017-01-30 13:43:35 -0800
---|---
committer | 2017-01-30 14:09:08 -0800
commit | 79a93ac627b9af8ae84a874ce248fe42aac8de36
tree | f0c82451dcf766c5e7ec04a30199c06f59c42752 /tensorflow/python/ops/embedding_ops.py
parent | 696dad03880e9b3367b6c4b4c3903d6aa723d7e5
Support partitioned embedding lookup for resource variables.
Change: 146034474
Diffstat (limited to 'tensorflow/python/ops/embedding_ops.py')
-rw-r--r-- | tensorflow/python/ops/embedding_ops.py | 29
1 file changed, 18 insertions(+), 11 deletions(-)
```diff
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 80507024d2..1e02cc6c48 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -33,6 +33,14 @@ from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 
 
+def _do_gather(params, ids, validate_indices=True, name=None):
+  """Deals with doing gather differently for resource variables."""
+  if isinstance(params, resource_variable_ops.ResourceVariable):
+    return params.sparse_read(ids, name=name)
+  return array_ops.gather(
+      params, ids, name=name, validate_indices=validate_indices)
+
+
 def embedding_lookup(params, ids, partition_strategy="mod", name=None,
                      validate_indices=True, max_norm=None):
   """Looks up `ids` in a list of embedding tensors.
@@ -100,16 +108,15 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
       return x
   with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
     np = len(params)  # Number of partitions
-    params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
+    # Preserve the resource variable status to avoid accidental dense reads.
+    if not any(isinstance(p, resource_variable_ops.ResourceVariable)
+               for p in params):
+      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
     if np == 1:
       with ops.colocate_with(params[0]):
-        # TODO(apassos): implement the sharded version as well.
-        if isinstance(params[0], resource_variable_ops.ResourceVariable):
-          ret = params[0].sparse_read(ids, name=name)
-        else:
-          ret = array_ops.gather(params[0], ids, name=name,
-                                 validate_indices=validate_indices)
-        return maybe_normalize(ret)
+        return maybe_normalize(
+            _do_gather(
+                params[0], ids, validate_indices=validate_indices, name=name))
     else:
       ids = ops.convert_to_tensor(ids, name="ids")
       flat_ids = array_ops.reshape(ids, [-1])
@@ -169,9 +176,9 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
       partitioned_result = []
       for p in xrange(np):
         with ops.colocate_with(params[p]):
-          partitioned_result.append(array_ops.gather(
-              params[p], gather_ids[p],
-              validate_indices=validate_indices))
+          partitioned_result.append(
+              _do_gather(params[p], gather_ids[p],
+                         validate_indices=validate_indices))
       # Stitch these back together
       ret = data_flow_ops.dynamic_stitch(pindices,
                                          partitioned_result, name=name)
```
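Below is a minimal usage sketch (not part of the commit) of the path this change enables: a partitioned `tf.nn.embedding_lookup` over shards that are resource variables, which now reaches `sparse_read` via `_do_gather` instead of being densely converted by `convert_n_to_tensor_or_indexed_slices`. It is written against the TF 1.x API this file targets; the shard names, shapes, and the `use_resource=True` flag on `tf.get_variable` are illustrative assumptions.

```python
# Sketch only: variable names and shapes are hypothetical, not from the commit.
import tensorflow as tf

# A 10-row embedding table split across two shards. Under the default
# "mod" strategy, shard p holds the rows i with i % 2 == p.
shards = [
    tf.get_variable("emb_shard_%d" % p, shape=[5, 4], use_resource=True)
    for p in range(2)
]

ids = tf.constant([0, 3, 8])
# np > 1, so this takes the partitioned branch; with this change each
# resource-variable shard is read via sparse_read inside _do_gather
# rather than being converted to a dense tensor first.
embedded = tf.nn.embedding_lookup(shards, ids, partition_strategy="mod")

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(embedded).shape)  # (3, 4)
```

With the `"mod"` strategy, ids `[0, 3, 8]` are routed to their shards (0 and 8 to shard 0, 3 to shard 1) and the per-shard results are reassembled by `dynamic_stitch`, as in the second hunk of the diff above.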