diff options
author | 2015-12-17 16:39:27 -0800 | |
---|---|---|
committer | 2015-12-17 16:39:27 -0800 | |
commit | f0f6e33a5e51863712707133d8b56e31d69c4d7c (patch) | |
tree | 833cff5a244714d15a229bf17c6776699a03425d /tensorflow/python/kernel_tests/embedding_ops_test.py | |
parent | 320952da3ef3ac167c0c5d79f5a1e70b81447618 (diff) |
Improve shape inference of embedding lookup sparse when using weights.
Change: 110498004
Diffstat (limited to 'tensorflow/python/kernel_tests/embedding_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/embedding_ops_test.py | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py index cdadff1567..b30b13a252 100644 --- a/tensorflow/python/kernel_tests/embedding_ops_test.py +++ b/tensorflow/python/kernel_tests/embedding_ops_test.py @@ -445,6 +445,7 @@ class EmbeddingLookupSparseTest(tf.test.TestCase): vocab_size = 13 batch_size = 10 param_shape = [2, 5] + expected_lookup_result_shape = [None] + param_shape sp_ids, sp_weights, ids, weights, vals_per_batch_entry = ( self._RandomIdsAndWeights(batch_size, vocab_size)) @@ -467,6 +468,10 @@ class EmbeddingLookupSparseTest(tf.test.TestCase): embedding_sum = tf.nn.embedding_lookup_sparse( p, sp_ids, None if ignore_weights else sp_weights, combiner=combiner) + + self.assertEqual(embedding_sum.get_shape().as_list(), + expected_lookup_result_shape) + tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict) np_embedding_sum, np_weight_sum = _EmbeddingResult( @@ -508,6 +513,21 @@ class EmbeddingLookupSparseTest(tf.test.TestCase): x_init_value=x_init_value) self.assertLess(err, 1e-5 if dtype == tf.float64 else 2e-3) + def testIncompatibleShapes(self): + with self.test_session(): + x, _, _ = _EmbeddingParams(1, 10, dtype=tf.float32) + sp_ids = tf.SparseTensor( + tf.constant([[0, 0], [0, 1], [1, 0]], tf.int64), + tf.constant([0, 1, 2], tf.int32), + tf.constant([2, 2], tf.int64)) + sp_weights = tf.SparseTensor( + tf.constant([[0, 0], [0, 1]], tf.int64), + tf.constant([12.0, 5.0], tf.float32), + tf.constant([1, 2], tf.int64)) + + with self.assertRaises(ValueError): + tf.nn.embedding_lookup_sparse(x, sp_ids, sp_weights, combiner="mean") + class DynamicStitchOpTest(tf.test.TestCase): |