aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/embedding_ops_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-16 07:01:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-16 07:08:43 -0800
commit12dcf29550cac2d2d1d69d828ac03594804afd41 (patch)
tree08869e4ad7f87a0271542a06f2f6bcbc106ad238 /tensorflow/python/kernel_tests/embedding_ops_test.py
parent9003342ea06eb68c5f4def6d8e20c5dee5a295f1 (diff)
Adds max_norm option to embedding_lookup (and upstream functions).
Change: 139325873
Diffstat (limited to 'tensorflow/python/kernel_tests/embedding_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/embedding_ops_test.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index 67199cecdd..09a31b147b 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -228,6 +228,26 @@ class EmbeddingLookupTest(tf.test.TestCase):
self.assertAllEqual(np_result, tf_result)
self.assertShapeEqual(np_result, embedding)
+ def testMaxNorm(self):
+ with self.test_session():
+ embeddings = tf.constant([[2.0]])
+
+ ids = tf.constant([0], dtype=tf.int32)
+ embedding = tf.nn.embedding_lookup([embeddings], ids, max_norm=1.0)
+
+ self.assertAllEqual(embedding.eval(), [[1.0]])
+
+ def testMaxNormNontrivial(self):
+ with self.test_session():
+ embeddings = tf.constant([[2.0, 4.0], [3.0, 1.0]])
+
+ ids = tf.constant([0, 1], dtype=tf.int32)
+ embedding = tf.nn.embedding_lookup([embeddings], ids, max_norm=2.0)
+
+ norms = tf.sqrt(tf.reduce_sum(embeddings * embeddings, axis=1))
+ normalized = embeddings/tf.stack([norms, norms], axis=1)
+ self.assertAllEqual(embedding.eval(), 2 * normalized.eval())
+
def testSimpleShardedPartitionedVariable(self):
with self.test_session() as sess:
num_shards = 2