Diffstat (limited to 'tensorflow/contrib/layers/python/layers/embedding_ops_test.py')
-rw-r--r--  tensorflow/contrib/layers/python/layers/embedding_ops_test.py | 291
1 file changed, 255 insertions(+), 36 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
index dfa8067f27..bf25144982 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
@@ -31,10 +31,13 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.platform import test
+from tensorflow.python.util import compat
class SafeEmbeddingLookupSparseTest(test.TestCase):
@@ -143,8 +146,8 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
self.assertAllClose(
embedding_lookup_result,
[(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4,
- [0] * 4, embedding_weights[0][2],
- (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])
+ [0] * 4, embedding_weights[0][2], (
+ embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])
def test_safe_embedding_lookup_sparse_partitioned(self):
with self.test_session():
@@ -169,8 +172,8 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids)
embedding_weights = [
- constant_op.constant(
- w, dtype=dtypes.float64) for w in embedding_weights
+ constant_op.constant(w, dtype=dtypes.float64)
+ for w in embedding_weights
]
self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids, sparse_weights)
@@ -183,11 +186,10 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, sparse_weights).eval())
- self.assertAllClose(
- embedding_lookup_result,
- [[(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) /
- 3.0, [0] * 4, [0] * 4],
- [embedding_weights[0][2], [0] * 4, [0] * 4]])
+ self.assertAllClose(embedding_lookup_result, [[
+ (1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 3.0,
+ [0] * 4, [0] * 4
+ ], [embedding_weights[0][2], [0] * 4, [0] * 4]])
def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
with self.test_session():
@@ -213,14 +215,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, None).eval())
- self.assertAllClose(
- embedding_lookup_result,
- [[(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4,
- [0] * 4], [
- embedding_weights[0][2],
- (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0,
- [0] * 4
- ]])
+ self.assertAllClose(embedding_lookup_result, [[(
+ embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, [
+ 0
+ ] * 4], [
+ embedding_weights[0][2],
+ (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4
+ ]])
def test_safe_embedding_lookup_sparse_3d_partitioned(self):
with self.test_session():
@@ -231,13 +232,12 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights, sparse_ids, None).eval())
embedding_weights = list(itertools.chain(*embedding_weights))
- self.assertAllClose(embedding_lookup_result,
- [[(embedding_weights[0] + embedding_weights[1]) / 2.0,
- [0] * 4, [0] * 4], [
- embedding_weights[2],
- (embedding_weights[0] + embedding_weights[1]) /
- 2.0, [0] * 4
- ]])
+ self.assertAllClose(embedding_lookup_result, [[
+ (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4, [0] * 4
+ ], [
+ embedding_weights[2],
+ (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4
+ ]])
def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
self):
@@ -249,8 +249,8 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids)
embedding_weights = [
- constant_op.constant(
- w, dtype=dtypes.float64) for w in embedding_weights
+ constant_op.constant(w, dtype=dtypes.float64)
+ for w in embedding_weights
]
self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids, sparse_weights)
@@ -299,8 +299,8 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
self.assertAllEqual(embedding_lookup_result[0],
embedding_lookup_result[1])
# Different embedding expected for different value.
- embedding_diff = np.min((embedding_lookup_result[2] -
- embedding_lookup_result[0])**2)
+ embedding_diff = np.min(
+ (embedding_lookup_result[2] - embedding_lookup_result[0])**2)
self.assertGreater(embedding_diff, 0)
def test_scattered_embedding_coverage(self):
@@ -318,8 +318,8 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
def test_scattered_embedding_multi_dimension(self):
with self.test_session():
embedding_weights = self._random_weights()
- values = constant_op.constant(
- [["foo", "bar", "bar"], ["bar", "bar", "foo"]])
+ values = constant_op.constant([["foo", "bar", "bar"],
+ ["bar", "bar", "foo"]])
embedding_lookup_result = embedding_ops.scattered_embedding_lookup(
embedding_weights, values, dimension=10).eval()
@@ -338,8 +338,8 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
embedding_lookup_result = (
embedding_ops.scattered_embedding_lookup_sparse(
- embedding_weights, sparse_tensor, dimension=5, combiner="mean")
- .eval())
+ embedding_weights, sparse_tensor, dimension=5,
+ combiner="mean").eval())
self.assertAllEqual(embedding_lookup_result.shape, [5, 5])
# Same non-zero embedding for the empty rows filled with a default value.
@@ -431,8 +431,8 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
def test_hashed_embedding_multi_dimension(self):
with self.test_session():
embedding_weights = self._random_weights()
- values = constant_op.constant(
- [["foo", "bar", "bar"], ["bar", "bar", "foo"]])
+ values = constant_op.constant([["foo", "bar", "bar"],
+ ["bar", "bar", "foo"]])
sampled_candidates = constant_op.constant(
[[[1, 3, 4, 6], [1, 7, 8, 9], [1, 7, 8, 9]],
[[1, 7, 8, 9], [1, 7, 8, 9], [1, 3, 4, 6]]])
@@ -489,8 +489,8 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
result = embedding_ops._sampled_scattered_embedding_lookup_sparse(
params, sp_values, dimension=5, hash_key=self._hash_key)
- self.assertAllClose(result.eval(), [[0., 0., 0., 0., 0.],
- [.3, .2, .2, .3, .1],
+ self.assertAllClose(result.eval(), [[0., 0., 0., 0.,
+ 0.], [.3, .2, .2, .3, .1],
[0., 0., 0., 0., 0.]])
def test_output_values_with_sampled_candidates(self):
@@ -563,5 +563,224 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
self.assertAllClose(result.eval(), result_abc.eval())
+def _PName(param_id):
+ return "p" + str(param_id)
+
+
+def _EmbeddingParams(num_shards,
+ vocab_size,
+ dtype=dtypes.float32,
+ shape=None,
+ use_shapeless_placeholder=False):
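+  # Builds `num_shards` embedding shards covering `vocab_size` rows. Returns
+  # the shard tensors, their intended values keyed by tensor name, and a
+  # feed dict mapping each tensor to its value.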
+ p = []
+ params = {}
+ feed_dict = {}
+ if not shape:
+ shape = [10]
+ for i in range(num_shards):
+ shard_shape = [vocab_size // num_shards] + shape
+ if i < vocab_size % num_shards: # Excess goes evenly on the first shards
+ shard_shape[0] += 1
+
+ param_name = _PName(i)
+
+ if use_shapeless_placeholder:
+ param = array_ops.placeholder(dtype, shape=None, name=param_name)
+ else:
+ param = constant_op.constant(
+ 1.0, shape=shard_shape, dtype=dtype, name=param_name)
+ p.append(param)
+ np_type = "f" if dtype == dtypes.float32 else "d"
+ val = (np.random.rand(*shard_shape).astype(np_type)) + 1
+ params[param_name + ":0"] = val
+ feed_dict[param.name] = val
+ return p, params, feed_dict
+
+
+def _EmbeddingResult(params,
+ id_vals,
+ num_shards,
+ vocab_size,
+ partition_strategy="mod",
+ weight_vals=None):
+ if weight_vals is None:
+ weight_vals = np.copy(id_vals)
+ weight_vals.fill(1)
+ values = []
+ weights = []
+ weights_squared = []
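+  # The per-entry weight sums (and squared sums) let callers renormalize the
+  # summed embeddings into "mean" and "sqrtn" combiner results.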
+ for ids, wts in zip(id_vals, weight_vals):
+ value_aggregation = None
+ weight_aggregation = None
+ squared_weight_aggregation = None
+ if isinstance(ids, compat.integral_types):
+ ids = [ids]
+ wts = [wts]
+ for i, weight_value in zip(ids, wts):
+ if partition_strategy == "mod":
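+        # "mod": id i lives in shard i % num_shards, at row i // num_shards.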
+ val = np.copy(params[_PName(i % num_shards) + ":0"][
+ i // num_shards, :]) * weight_value
+ elif partition_strategy == "div":
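+        # "div": contiguous ranges of ids per shard; the first `extras`
+        # shards each hold one extra id.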
+ ids_per_partition, extras = divmod(vocab_size, num_shards)
+ threshold = extras * (ids_per_partition + 1)
+ if i < threshold:
+ partition = i // (ids_per_partition + 1)
+ offset = i % (ids_per_partition + 1)
+ else:
+ partition = extras + (i - threshold) // ids_per_partition
+ offset = (i - threshold) % ids_per_partition
+ val = np.copy(
+ params[_PName(partition) + ":0"][offset, :]) * weight_value
+ else:
+ assert False
+ if value_aggregation is None:
+ assert weight_aggregation is None
+ assert squared_weight_aggregation is None
+ value_aggregation = val
+ weight_aggregation = weight_value
+ squared_weight_aggregation = weight_value * weight_value
+ else:
+ assert weight_aggregation is not None
+ assert squared_weight_aggregation is not None
+ value_aggregation += val
+ weight_aggregation += weight_value
+ squared_weight_aggregation += weight_value * weight_value
+ values.append(value_aggregation)
+ weights.append(weight_aggregation)
+ weights_squared.append(squared_weight_aggregation)
+ values = np.array(values).astype(np.float32)
+ weights = np.array(weights).astype(np.float32)
+ weights_squared = np.array(weights_squared).astype(np.float32)
+ return values, weights, weights_squared
+
+
+class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
+
+ def _RandomIdsAndWeights(self, batch_size, vocab_size):
+ max_val_per_entry = 6
+ vals_per_batch_entry = np.random.randint(
+ 1, max_val_per_entry, size=batch_size)
+ num_vals = np.sum(vals_per_batch_entry)
+
+ ids = np.random.randint(vocab_size, size=num_vals)
+ weights = 1 + np.random.rand(num_vals)
+
+ indices = []
+ for batch_entry, num_val in enumerate(vals_per_batch_entry):
+ for val_index in range(num_val):
+ indices.append([batch_entry, val_index])
+
+ shape = [batch_size, max_val_per_entry]
+
+ sp_ids = sparse_tensor_lib.SparseTensor(
+ constant_op.constant(indices, dtypes.int64),
+ constant_op.constant(ids, dtypes.int32),
+ constant_op.constant(shape, dtypes.int64))
+ sp_weights = sparse_tensor_lib.SparseTensor(
+ constant_op.constant(indices, dtypes.int64),
+ constant_op.constant(weights, dtypes.float32),
+ constant_op.constant(shape, dtypes.int64))
+
+ return sp_ids, sp_weights, ids, weights, vals_per_batch_entry
+
+ def _GroupByBatchEntry(self, vals, vals_per_batch_entry):
+ grouped_vals = []
+ index = 0
+ for num_val in vals_per_batch_entry:
+ grouped_vals.append(list(vals[index:(index + num_val)]))
+ index += num_val
+ return grouped_vals
+
+ def testEmbeddingLookupSparse(self):
+ vocab_size = 13
+ batch_size = 10
+ param_shape = [2, 5]
+ expected_lookup_result_shape = [None] + param_shape
+
+ sp_ids, sp_weights, ids, weights, vals_per_batch_entry = (
+ self._RandomIdsAndWeights(batch_size, vocab_size))
+
+ grouped_ids = self._GroupByBatchEntry(ids, vals_per_batch_entry)
+ grouped_weights = self._GroupByBatchEntry(weights, vals_per_batch_entry)
+ grouped_ignored_weights = self._GroupByBatchEntry(
+ np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry)
+
+ for num_shards, combiner, dtype, ignore_weights in itertools.product(
+ [1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32,
+ dtypes.float64], [True, False]):
+
+ with self.test_session():
+ p, params, feed_dict = _EmbeddingParams(
+ num_shards, vocab_size, shape=param_shape, dtype=dtype)
+ embedding_sum = \
+ embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
+ p,
+ sp_ids,
+ None if ignore_weights else sp_weights,
+ combiner=combiner)
+
+ self.assertEqual(embedding_sum.get_shape().as_list(),
+ expected_lookup_result_shape)
+
+ tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)
+
+ np_embedding_sum, np_weight_sum, np_weight_sq_sum = _EmbeddingResult(
+ params,
+ grouped_ids,
+ num_shards,
+ vocab_size,
+ weight_vals=grouped_ignored_weights
+ if ignore_weights else grouped_weights)
+ if combiner == "mean":
+ np_embedding_sum /= np.reshape(np_weight_sum, (batch_size, 1, 1))
+ if combiner == "sqrtn":
+ np_embedding_sum /= np.reshape(
+ np.sqrt(np_weight_sq_sum), (batch_size, 1, 1))
+ self.assertAllClose(np_embedding_sum, tf_embedding_sum)
+
+ def testGradientsEmbeddingLookupSparse(self):
+ vocab_size = 12
+ batch_size = 4
+ param_shape = [2, 3]
+ sp_ids, sp_weights, _, _, _ = (self._RandomIdsAndWeights(
+ batch_size, vocab_size))
+
+ for num_shards, combiner, dtype, ignore_weights in itertools.product(
+ [1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32,
+ dtypes.float64], [True, False]):
+ with self.test_session():
+ x, params, _ = _EmbeddingParams(
+ num_shards, vocab_size, shape=param_shape, dtype=dtype)
+
+ y = embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
+ x,
+ sp_ids,
+ None if ignore_weights else sp_weights,
+ combiner=combiner)
+ x_name = [_PName(i) for i in range(num_shards)]
+ x_init_value = [params[x_n + ":0"] for x_n in x_name]
+ x_shape = [i.shape for i in x_init_value]
+ y_shape = [batch_size] + list(params[_PName(0) + ":0"].shape[1:])
+ err = gradient_checker.compute_gradient_error(
+ x, x_shape, y, y_shape, x_init_value=x_init_value)
+ self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3)
+
+ def testIncompatibleShapes(self):
+ with self.test_session():
+ x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
+ sp_ids = sparse_tensor_lib.SparseTensor(
+ constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),
+ constant_op.constant([0, 1, 2], dtypes.int32),
+ constant_op.constant([2, 2], dtypes.int64))
+ sp_weights = sparse_tensor_lib.SparseTensor(
+ constant_op.constant([[0, 0], [0, 1]], dtypes.int64),
+ constant_op.constant([12.0, 5.0], dtypes.float32),
+ constant_op.constant([1, 2], dtypes.int64))
+
+ with self.assertRaises(ValueError):
+ embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
+ x, sp_ids, sp_weights, combiner="mean")
+
+
if __name__ == "__main__":
test.main()
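
Usage note: the `_EmbeddingResult` helper added above reproduces, in NumPy,
the combiner arithmetic that the distributed-aggregation lookup is checked
against. A minimal sketch of that arithmetic, assuming a single shard, the
default "mod" partition strategy, and an illustrative 4 x 2 table (the toy
values here are not part of the patch):

  import numpy as np

  # One shard "p0" holding a toy 4 x 2 embedding table.
  params = {"p0:0": np.arange(8, dtype=np.float32).reshape(4, 2)}
  id_vals = [[0, 2], [1, 3]]             # two batch entries, two ids each
  weight_vals = [[1.0, 3.0], [2.0, 2.0]]

  values, weights, weights_squared = _EmbeddingResult(
      params, id_vals, num_shards=1, vocab_size=4, weight_vals=weight_vals)

  # "sum" is `values` as-is; "mean" and "sqrtn" renormalize it, mirroring
  # the np_embedding_sum adjustments in testEmbeddingLookupSparse above.
  mean = values / weights[:, np.newaxis]
  sqrtn = values / np.sqrt(weights_squared)[:, np.newaxis]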