about summary refs log tree commit diff homepage
path: root/tensorflow/python/kernel_tests/embedding_ops_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-29 18:51:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-29 18:55:02 -0700
commit822d64f0c699528ba41386cc60dcfbeba825c3f8 (patch)
treec2c3d78e9c07d6a80f737ac1c09d2b32d024dbbf /tensorflow/python/kernel_tests/embedding_ops_test.py
parent8cad6b824ea0cc8fee6138654912eb9b6a9933a6 (diff)
Fix embedding_lookup() bug where normalization did not work with ids of rank != 1.
PiperOrigin-RevId: 157422220
Diffstat (limited to 'tensorflow/python/kernel_tests/embedding_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/embedding_ops_test.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index 2bd21fb01d..057da9d7af 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -547,6 +547,31 @@ class EmbeddingLookupTest(test.TestCase):
sharded = embedding_ops.embedding_lookup(split_params, ids).eval()
self.assertAllEqual(simple, sharded)
def testHigherRankMaxNorm(self):
  """Checks that max_norm normalization works for ids of rank != 1.

  Regression test for a bug where embedding_lookup() ignored/mishandled
  max_norm when the ids tensor was not rank 1.  Covers scalar (rank-0),
  rank-1, rank-2 and rank-3 ids, against both 1-D and 2-D params, for
  nonsharded and sharded (1/2/3-way mod-partitioned) params.
  """
  np.random.seed(8)
  with self.test_session():
    for params_shape in (12,), (6, 3):
      # All-twos params: dividing each row by its L2 norm gives the
      # expected result of embedding_lookup(..., max_norm=1.0), since
      # every row norm exceeds 1.0 and is clipped to exactly 1.0.
      params = 2 * np.ones(params_shape)
      params_norm = params / np.sqrt(
          np.sum(
              params * params, tuple(range(params.ndim)[1:]), keepdims=True))
      # NOTE: (3,) — the original used the bare literal (3), which is the
      # int 3, not a 1-tuple; it only worked because np.prod() and
      # reshape() accept a plain int.
      for ids_shape in (), (3,), (4, 3), (2, 3, 4):
        ids = np.random.randint(
            params.shape[0],
            size=np.prod(ids_shape, dtype=np.int64)).reshape(ids_shape)
        # Compare nonsharded lookup to a gather from the pre-normalized
        # params.
        simple = embedding_ops.embedding_lookup(
            params, ids, max_norm=1.0).eval()
        self.assertAllEqual(simple, array_ops.gather(params_norm, ids).eval())
        # Run a few random sharded versions: mod-partition the params
        # across `procs` shards and check the result is unchanged.
        for procs in 1, 2, 3:
          stride = procs * math_ops.range(params.shape[0] // procs)
          split_params = [
              array_ops.gather(params, stride + p) for p in xrange(procs)
          ]
          sharded = embedding_ops.embedding_lookup(
              split_params, ids, max_norm=1.0).eval()
          self.assertAllEqual(simple, sharded)
class EmbeddingLookupSparseTest(test.TestCase):