diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-11-16 07:01:55 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-16 07:08:43 -0800 |
commit | 12dcf29550cac2d2d1d69d828ac03594804afd41 (patch) | |
tree | 08869e4ad7f87a0271542a06f2f6bcbc106ad238 /tensorflow/python/kernel_tests/embedding_ops_test.py | |
parent | 9003342ea06eb68c5f4def6d8e20c5dee5a295f1 (diff) |
Adds max_norm option to embedding_lookup (and upstream functions).
Change: 139325873
Diffstat (limited to 'tensorflow/python/kernel_tests/embedding_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/embedding_ops_test.py | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py index 67199cecdd..09a31b147b 100644 --- a/tensorflow/python/kernel_tests/embedding_ops_test.py +++ b/tensorflow/python/kernel_tests/embedding_ops_test.py @@ -228,6 +228,26 @@ class EmbeddingLookupTest(tf.test.TestCase): self.assertAllEqual(np_result, tf_result) self.assertShapeEqual(np_result, embedding) + def testMaxNorm(self): + with self.test_session(): + embeddings = tf.constant([[2.0]]) + + ids = tf.constant([0], dtype=tf.int32) + embedding = tf.nn.embedding_lookup([embeddings], ids, max_norm=1.0) + + self.assertAllEqual(embedding.eval(), [[1.0]]) + + def testMaxNormNontrivial(self): + with self.test_session(): + embeddings = tf.constant([[2.0, 4.0], [3.0, 1.0]]) + + ids = tf.constant([0, 1], dtype=tf.int32) + embedding = tf.nn.embedding_lookup([embeddings], ids, max_norm=2.0) + + norms = tf.sqrt(tf.reduce_sum(embeddings * embeddings, axis=1)) + normalized = embeddings/tf.stack([norms, norms], axis=1) + self.assertAllEqual(embedding.eval(), 2 * normalized.eval()) + def testSimpleShardedPartitionedVariable(self): with self.test_session() as sess: num_shards = 2 |