-rw-r--r--  tensorflow/python/kernel_tests/embedding_ops_test.py  25
-rw-r--r--  tensorflow/python/ops/embedding_ops.py                25
2 files changed, 43 insertions, 7 deletions
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index 2bd21fb01d..057da9d7af 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -547,6 +547,31 @@ class EmbeddingLookupTest(test.TestCase):
             sharded = embedding_ops.embedding_lookup(split_params, ids).eval()
             self.assertAllEqual(simple, sharded)
+  def testHigherRankMaxNorm(self):
+    np.random.seed(8)
+    with self.test_session():
+      for params_shape in (12,), (6, 3):
+        params = 2 * np.ones(params_shape)
+        params_norm = params / np.sqrt(
+            np.sum(params * params, tuple(range(params.ndim)[1:]), keepdims=True))
+        for ids_shape in (), (3), (4, 3), (2, 3, 4):
+          ids = np.random.randint(
+              params.shape[0], size=np.prod(ids_shape, dtype=np.int64)).reshape(
+                  ids_shape)
+          # Compare nonsharded to gather
+          simple = embedding_ops.embedding_lookup(
+              params, ids, max_norm=1.0).eval()
+          self.assertAllEqual(simple, array_ops.gather(params_norm, ids).eval())
+          # Run a few random sharded versions
+          for procs in 1, 2, 3:
+            stride = procs * math_ops.range(params.shape[0] // procs)
+            split_params = [
+                array_ops.gather(params, stride + p) for p in xrange(procs)
+            ]
+            sharded = embedding_ops.embedding_lookup(
+                split_params, ids, max_norm=1.0).eval()
+            self.assertAllEqual(simple, sharded)
+
 class EmbeddingLookupSparseTest(test.TestCase):
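
For reference, the expected values this test checks follow from clipping each looked-up embedding vector to L2 norm 1.0 (clip_by_norm only rescales vectors whose norm exceeds max_norm). A minimal NumPy sketch of that computation, separate from the commit itself; the array values below are purely illustrative:

import numpy as np

params = 2 * np.ones((6, 3))       # every row has L2 norm 2 * sqrt(3) > 1
ids = np.array([[0, 2], [5, 1]])   # higher-rank ids, shape (2, 2)

# Rescale any row whose L2 norm exceeds 1.0, then gather -- the same result
# embedding_lookup(params, ids, max_norm=1.0) is expected to produce.
row_norms = np.sqrt(np.sum(params * params, axis=1, keepdims=True))
expected = (params / np.maximum(row_norms, 1.0))[ids]   # shape (2, 2, 3)

assert np.allclose(np.linalg.norm(expected, axis=-1), 1.0)
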
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 315e7d4b43..6930f9af05 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -103,14 +103,25 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
     params = list(params)  # Iterate to get the underlying Variables.
   if not isinstance(params, list):
     params = [params]
+
   def maybe_normalize(x):
-    if max_norm is not None:
-      if x.get_shape().ndims is not None:
-        ndims = x.get_shape().ndims
-      else:
-        ndims = array_ops.size(array_ops.shape(x))
-      return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
-    return x
+    """Normalizes the embeddings in x if max_norm is not None."""
+    if max_norm is None:
+      return x
+    static = True
+    ids_rank = ops.convert_to_tensor(ids).get_shape().ndims
+    if ids_rank is None:
+      ids_rank = array_ops.rank(ids)
+      static = False
+    x_rank = x.get_shape().ndims
+    if x_rank is None:
+      x_rank = array_ops.rank(x)
+      static = False
+    return clip_ops.clip_by_norm(
+        x, max_norm,
+        axes=list(range(ids_rank, x_rank)) if static
+        else math_ops.range(ids_rank, x_rank))
+
   with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
     np = len(params)  # Number of partitions
     # Preserve the resource variable status to avoid accidental dense reads.
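
With this change, max_norm is applied per embedding vector even when ids has rank other than 1 and when params is sharded across several tensors, which is what the new test exercises. A small usage sketch under that assumption (TF 1.x graph mode; the shard sizes and ids below are made up for illustration):

import numpy as np
import tensorflow as tf

# Two shards of a 6 x 4 embedding table, every entry equal to 2.0.
params = [tf.constant(2.0, shape=(3, 4)), tf.constant(2.0, shape=(3, 4))]
ids = tf.constant([[0, 5], [2, 3]])   # rank-2 ids

# Each looked-up vector is clipped to L2 norm 1.0 before being returned.
embedded = tf.nn.embedding_lookup(params, ids, max_norm=1.0)

with tf.Session() as sess:
  out = sess.run(embedded)
  print(out.shape)                       # (2, 2, 4)
  print(np.linalg.norm(out, axis=-1))    # all close to 1.0
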