diff options
author | 2017-05-29 18:51:43 -0700 | |
---|---|---|
committer | 2017-05-29 18:55:02 -0700 | |
commit | 822d64f0c699528ba41386cc60dcfbeba825c3f8 (patch) | |
tree | c2c3d78e9c07d6a80f737ac1c09d2b32d024dbbf /tensorflow/python/kernel_tests/embedding_ops_test.py | |
parent | 8cad6b824ea0cc8fee6138654912eb9b6a9933a6 (diff) |
Fix embedding_lookup() bug where normalization did not work with ids of rank != 1.
PiperOrigin-RevId: 157422220
Diffstat (limited to 'tensorflow/python/kernel_tests/embedding_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/embedding_ops_test.py | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py index 2bd21fb01d..057da9d7af 100644 --- a/tensorflow/python/kernel_tests/embedding_ops_test.py +++ b/tensorflow/python/kernel_tests/embedding_ops_test.py @@ -547,6 +547,31 @@ class EmbeddingLookupTest(test.TestCase): sharded = embedding_ops.embedding_lookup(split_params, ids).eval() self.assertAllEqual(simple, sharded) + def testHigherRankMaxNorm(self): + np.random.seed(8) + with self.test_session(): + for params_shape in (12,), (6, 3): + params = 2 * np.ones(params_shape) + params_norm = params / np.sqrt( + np.sum(params*params, tuple(range(params.ndim)[1:]), keepdims=True)) + for ids_shape in (), (3), (4, 3), (2, 3, 4): + ids = np.random.randint( + params.shape[0], size=np.prod(ids_shape, dtype=np.int64)).reshape( + ids_shape) + # Compare nonsharded to gather + simple = embedding_ops.embedding_lookup( + params, ids, max_norm=1.0).eval() + self.assertAllEqual(simple, array_ops.gather(params_norm, ids).eval()) + # Run a few random sharded versions + for procs in 1, 2, 3: + stride = procs * math_ops.range(params.shape[0] // procs) + split_params = [ + array_ops.gather(params, stride + p) for p in xrange(procs) + ] + sharded = embedding_ops.embedding_lookup( + split_params, ids, max_norm=1.0).eval() + self.assertAllEqual(simple, sharded) + class EmbeddingLookupSparseTest(test.TestCase): |