path: root/tensorflow/python/kernel_tests/embedding_ops_test.py
author     Derek Murray <mrry@google.com>  2015-12-16 17:54:29 -0800
committer  Vijay Vasudevan <vrv@google.com>  2015-12-16 17:54:29 -0800
commit     f2eae4b3d27a4dc6d1f591f55a50fb3e1984a287 (patch)
tree       b85d8ee9d9bff6b4aadfc19955b1aa0a5d4e7b62 /tensorflow/python/kernel_tests/embedding_ops_test.py
parent     ee4f4409fe5d994a0d01a4c04441847218527ce0 (diff)
Fix the gradient of `tf.gather()` when the indices are > 1-D.
The gradient function was previously generating an invalid IndexedSlices, in which the `IndexedSlices.indices` tensor was not a vector. This change reshapes the indices and the gradient so that they can correctly be interpreted as an IndexedSlices and applied to the embedding variable.

Added a multi-dimensional gradient test in embedding_ops_test.py.

Fixes #505. Partially addresses #464.
Change: 110364370
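For illustration, the reshaping described above can be sketched as follows. This is a minimal sketch, not the gradient function from this change: the helper name `_gather_grad_sketch` is hypothetical, and it is written against the later `tf.concat`/`tf.IndexedSlices` API for readability.

    import tensorflow as tf

    def _gather_grad_sketch(params, indices, grad):
      # Sketch only (not the actual code in this change): build a valid
      # IndexedSlices gradient for y = tf.gather(params, indices) when
      # `indices` may be > 1-D.
      params_shape = tf.shape(params)
      # Collapse the leading index dimensions of the incoming gradient so
      # its first axis lines up one-to-one with a vector of indices.
      values = tf.reshape(grad, tf.concat([[-1], params_shape[1:]], axis=0))
      # Flatten the indices to a vector, as IndexedSlices.indices requires.
      flat_indices = tf.reshape(indices, [-1])
      return tf.IndexedSlices(values, flat_indices, dense_shape=params_shape)

For example, with `params` of shape (9, 2) and `indices` of shape (2, 5), `grad` has shape (2, 5, 2); the sketch reshapes it to (10, 2) values paired with a length-10 indices vector.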
Diffstat (limited to 'tensorflow/python/kernel_tests/embedding_ops_test.py')
-rw-r--r--  tensorflow/python/kernel_tests/embedding_ops_test.py | 35
1 file changed, 18 insertions(+), 17 deletions(-)
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index 9a4cd14eb1..cdadff1567 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -325,25 +325,26 @@ class EmbeddingLookupTest(tf.test.TestCase):
 
   def testGradientsEmbeddingLookup(self):
     vocab_size = 9
-    num_ids = 5
+    num_ids = 10
     id_vals = list(np.random.randint(vocab_size, size=num_ids))
     tf.logging.vlog(1, id_vals)
-    for num_shards in [1, 3]:
-      with self.test_session():
-        ids = tf.constant(id_vals, dtype=tf.int32)
-        x, params, _ = _EmbeddingParams(
-            num_shards, vocab_size, shape=[2])
-        y = tf.nn.embedding_lookup(x, ids)
-        y_shape = [num_ids] + list(params[_PName(0) + ":0"].shape[1:])
-        x_name = [_PName(i) for i in range(num_shards)]
-        x_init_value = [params[x_n + ":0"] for x_n in x_name]
-        x_shape = [i.shape for i in x_init_value]
-        err = tf.test.compute_gradient_error(x,
-                                             x_shape,
-                                             y,
-                                             y_shape,
-                                             x_init_value=x_init_value)
-        self.assertLess(err, 1e-4)
+    for ids_shape in [(10,), (2, 5)]:
+      for num_shards in [1, 3]:
+        with self.test_session():
+          ids = tf.constant(id_vals, shape=ids_shape, dtype=tf.int32)
+          x, params, _ = _EmbeddingParams(
+              num_shards, vocab_size, shape=[2])
+          y = tf.nn.embedding_lookup(x, ids)
+          y_shape = [num_ids] + list(params[_PName(0) + ":0"].shape[1:])
+          x_name = [_PName(i) for i in range(num_shards)]
+          x_init_value = [params[x_n + ":0"] for x_n in x_name]
+          x_shape = [i.shape for i in x_init_value]
+          err = tf.test.compute_gradient_error(x,
+                                               x_shape,
+                                               y,
+                                               y_shape,
+                                               x_init_value=x_init_value)
+          self.assertLess(err, 1e-4)
 
   def testGradientsEmbeddingLookupWithComputedParams(self):
     vocab_size = 9
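For context, a minimal graph-mode reproduction of the failure that the new (2, 5)-shaped case guards against might look like the snippet below. It is illustrative only, assuming TF 1.x-style `tf.gradients`; it is not part of this change.

    import tensorflow as tf

    params = tf.Variable(tf.ones([9, 2]))
    ids = tf.constant([[0, 1, 2, 3, 4], [5, 6, 7, 8, 0]])  # 2-D indices
    y = tf.gather(params, ids)                             # shape (2, 5, 2)
    # Before this fix, the gradient below was an IndexedSlices whose
    # `indices` field was not a vector, so applying it to `params` failed.
    grad, = tf.gradients(tf.reduce_sum(y), [params])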