aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/embedding_ops_test.py
diff options
context:
space:
mode:
authorGravatar Yutaka Leon <yutaka.leon@gmail.com>2015-12-17 16:39:27 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2015-12-17 16:39:27 -0800
commitf0f6e33a5e51863712707133d8b56e31d69c4d7c (patch)
tree833cff5a244714d15a229bf17c6776699a03425d /tensorflow/python/kernel_tests/embedding_ops_test.py
parent320952da3ef3ac167c0c5d79f5a1e70b81447618 (diff)
Improve shape inference of embedding lookup sparse when using weights.
Change: 110498004
Diffstat (limited to 'tensorflow/python/kernel_tests/embedding_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/embedding_ops_test.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index cdadff1567..b30b13a252 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -445,6 +445,7 @@ class EmbeddingLookupSparseTest(tf.test.TestCase):
vocab_size = 13
batch_size = 10
param_shape = [2, 5]
+ expected_lookup_result_shape = [None] + param_shape
sp_ids, sp_weights, ids, weights, vals_per_batch_entry = (
self._RandomIdsAndWeights(batch_size, vocab_size))
@@ -467,6 +468,10 @@ class EmbeddingLookupSparseTest(tf.test.TestCase):
embedding_sum = tf.nn.embedding_lookup_sparse(
p, sp_ids, None if ignore_weights else sp_weights,
combiner=combiner)
+
+ self.assertEqual(embedding_sum.get_shape().as_list(),
+ expected_lookup_result_shape)
+
tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)
np_embedding_sum, np_weight_sum = _EmbeddingResult(
@@ -508,6 +513,21 @@ class EmbeddingLookupSparseTest(tf.test.TestCase):
x_init_value=x_init_value)
self.assertLess(err, 1e-5 if dtype == tf.float64 else 2e-3)
+ def testIncompatibleShapes(self):
+ with self.test_session():
+ x, _, _ = _EmbeddingParams(1, 10, dtype=tf.float32)
+ sp_ids = tf.SparseTensor(
+ tf.constant([[0, 0], [0, 1], [1, 0]], tf.int64),
+ tf.constant([0, 1, 2], tf.int32),
+ tf.constant([2, 2], tf.int64))
+ sp_weights = tf.SparseTensor(
+ tf.constant([[0, 0], [0, 1]], tf.int64),
+ tf.constant([12.0, 5.0], tf.float32),
+ tf.constant([1, 2], tf.int64))
+
+ with self.assertRaises(ValueError):
+ tf.nn.embedding_lookup_sparse(x, sp_ids, sp_weights, combiner="mean")
+
class DynamicStitchOpTest(tf.test.TestCase):