aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/layers
diff options
context:
space:
mode:
authorGravatar candy.dc <dingchen.mail@gmail.com>2018-07-26 11:36:30 +0800
committerGravatar candy.dc <dingchen.mail@gmail.com>2018-07-26 11:36:30 +0800
commit9bab0c89c4ffeeb780e7a3dc415ab888164b9b00 (patch)
treeffeebac847f5d1ba5ca551b86533490d337a1193 /tensorflow/contrib/layers
parent5d92abe1e426b85ef549a8ce811628d708f7d914 (diff)
fix: No need to convert to tensor when using ResourceVariable in embedding_lookup,
because ResourceVariable support ResourceGather OP.
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r--tensorflow/contrib/layers/python/layers/embedding_ops.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py
index 60e1d85ea9..897aed527d 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py
@@ -112,9 +112,10 @@ def safe_embedding_lookup_sparse(embedding_weights,
dtype = sparse_weights.dtype if sparse_weights is not None else None
if isinstance(embedding_weights, variables.PartitionedVariable):
embedding_weights = list(embedding_weights)
- embedding_weights = [
- ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
- ]
+ if not isinstance(embedding_weights[0], resource_variable_ops.ResourceVariable):
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
contrib_tensor_util.assert_same_float_dtype(embedding_weights +
[sparse_weights])