diff options
author | Derek Murray <mrry@google.com> | 2015-12-16 17:54:29 -0800 |
---|---|---|
committer | Vijay Vasudevan <vrv@google.com> | 2015-12-16 17:54:29 -0800 |
commit | f2eae4b3d27a4dc6d1f591f55a50fb3e1984a287 (patch) | |
tree | b85d8ee9d9bff6b4aadfc19955b1aa0a5d4e7b62 /tensorflow/python/kernel_tests/embedding_ops_test.py | |
parent | ee4f4409fe5d994a0d01a4c04441847218527ce0 (diff) |
Fix the gradient of `tf.gather()` when the indices are > 1-D.
The gradient function was previously generating an invalid
IndexedSlices, whereby `IndexedSlices.indices` tensor was not a
vector. This change reshapes the indices and gradient so that they can
correctly be interpreted as an IndexedSlices and applied to the
embedding variable.
Added a multi-dimensional gradient test in embedding_ops_test.py.
Fixes #505. Partially addresses #464.
Change: 110364370
Diffstat (limited to 'tensorflow/python/kernel_tests/embedding_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/embedding_ops_test.py | 35 |
1 files changed, 18 insertions, 17 deletions
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py index 9a4cd14eb1..cdadff1567 100644 --- a/tensorflow/python/kernel_tests/embedding_ops_test.py +++ b/tensorflow/python/kernel_tests/embedding_ops_test.py @@ -325,25 +325,26 @@ class EmbeddingLookupTest(tf.test.TestCase): def testGradientsEmbeddingLookup(self): vocab_size = 9 - num_ids = 5 + num_ids = 10 id_vals = list(np.random.randint(vocab_size, size=num_ids)) tf.logging.vlog(1, id_vals) - for num_shards in [1, 3]: - with self.test_session(): - ids = tf.constant(id_vals, dtype=tf.int32) - x, params, _ = _EmbeddingParams( - num_shards, vocab_size, shape=[2]) - y = tf.nn.embedding_lookup(x, ids) - y_shape = [num_ids] + list(params[_PName(0) + ":0"].shape[1:]) - x_name = [_PName(i) for i in range(num_shards)] - x_init_value = [params[x_n + ":0"] for x_n in x_name] - x_shape = [i.shape for i in x_init_value] - err = tf.test.compute_gradient_error(x, - x_shape, - y, - y_shape, - x_init_value=x_init_value) - self.assertLess(err, 1e-4) + for ids_shape in [(10,), (2, 5)]: + for num_shards in [1, 3]: + with self.test_session(): + ids = tf.constant(id_vals, shape=ids_shape, dtype=tf.int32) + x, params, _ = _EmbeddingParams( + num_shards, vocab_size, shape=[2]) + y = tf.nn.embedding_lookup(x, ids) + y_shape = [num_ids] + list(params[_PName(0) + ":0"].shape[1:]) + x_name = [_PName(i) for i in range(num_shards)] + x_init_value = [params[x_n + ":0"] for x_n in x_name] + x_shape = [i.shape for i in x_init_value] + err = tf.test.compute_gradient_error(x, + x_shape, + y, + y_shape, + x_init_value=x_init_value) + self.assertLess(err, 1e-4) def testGradientsEmbeddingLookupWithComputedParams(self): vocab_size = 9 |