diff options
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/embedding_ops_test.py')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/embedding_ops_test.py | 291 |
1 files changed, 255 insertions, 36 deletions
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py index dfa8067f27..bf25144982 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py @@ -31,10 +31,13 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.platform import test +from tensorflow.python.util import compat class SafeEmbeddingLookupSparseTest(test.TestCase): @@ -143,8 +146,8 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): self.assertAllClose( embedding_lookup_result, [(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, - [0] * 4, embedding_weights[0][2], - (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0]) + [0] * 4, embedding_weights[0][2], ( + embedding_weights[0][0] + embedding_weights[0][1]) / 2.0]) def test_safe_embedding_lookup_sparse_partitioned(self): with self.test_session(): @@ -169,8 +172,8 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse, embedding_weights, sparse_ids) embedding_weights = [ - constant_op.constant( - w, dtype=dtypes.float64) for w in embedding_weights + constant_op.constant(w, dtype=dtypes.float64) + for w in embedding_weights ] self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse, embedding_weights, sparse_ids, sparse_weights) @@ -183,11 +186,10 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( embedding_weights, sparse_ids, sparse_weights).eval()) - self.assertAllClose( - embedding_lookup_result, - [[(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / - 3.0, [0] * 4, [0] * 4], - [embedding_weights[0][2], [0] * 4, [0] * 4]]) + self.assertAllClose(embedding_lookup_result, [[ + (1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 3.0, + [0] * 4, [0] * 4 + ], [embedding_weights[0][2], [0] * 4, [0] * 4]]) def test_safe_embedding_lookup_sparse_3d_return_special_vector(self): with self.test_session(): @@ -213,14 +215,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( embedding_weights, sparse_ids, None).eval()) - self.assertAllClose( - embedding_lookup_result, - [[(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, - [0] * 4], [ - embedding_weights[0][2], - (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, - [0] * 4 - ]]) + self.assertAllClose(embedding_lookup_result, [[( + embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, [ + 0 + ] * 4], [ + embedding_weights[0][2], + (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4 + ]]) def test_safe_embedding_lookup_sparse_3d_partitioned(self): with self.test_session(): @@ -231,13 +232,12 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): embedding_weights, sparse_ids, None).eval()) embedding_weights = list(itertools.chain(*embedding_weights)) - self.assertAllClose(embedding_lookup_result, - [[(embedding_weights[0] + embedding_weights[1]) / 2.0, - [0] * 4, [0] * 4], [ - embedding_weights[2], - (embedding_weights[0] + embedding_weights[1]) / - 2.0, [0] * 4 - ]]) + self.assertAllClose(embedding_lookup_result, [[ + (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4, [0] * 4 + ], [ + embedding_weights[2], + (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4 + ]]) def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights( self): @@ -249,8 +249,8 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse, embedding_weights, sparse_ids) embedding_weights = [ - constant_op.constant( - w, dtype=dtypes.float64) for w in embedding_weights + constant_op.constant(w, dtype=dtypes.float64) + for w in embedding_weights ] self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse, embedding_weights, sparse_ids, sparse_weights) @@ -299,8 +299,8 @@ class ScatteredEmbeddingLookupTest(test.TestCase): self.assertAllEqual(embedding_lookup_result[0], embedding_lookup_result[1]) # Different embedding expected for different value. - embedding_diff = np.min((embedding_lookup_result[2] - - embedding_lookup_result[0])**2) + embedding_diff = np.min( + (embedding_lookup_result[2] - embedding_lookup_result[0])**2) self.assertGreater(embedding_diff, 0) def test_scattered_embedding_coverage(self): @@ -318,8 +318,8 @@ class ScatteredEmbeddingLookupTest(test.TestCase): def test_scattered_embedding_multi_dimension(self): with self.test_session(): embedding_weights = self._random_weights() - values = constant_op.constant( - [["foo", "bar", "bar"], ["bar", "bar", "foo"]]) + values = constant_op.constant([["foo", "bar", "bar"], + ["bar", "bar", "foo"]]) embedding_lookup_result = embedding_ops.scattered_embedding_lookup( embedding_weights, values, dimension=10).eval() @@ -338,8 +338,8 @@ class ScatteredEmbeddingLookupTest(test.TestCase): embedding_lookup_result = ( embedding_ops.scattered_embedding_lookup_sparse( - embedding_weights, sparse_tensor, dimension=5, combiner="mean") - .eval()) + embedding_weights, sparse_tensor, dimension=5, + combiner="mean").eval()) self.assertAllEqual(embedding_lookup_result.shape, [5, 5]) # Same non-zero embedding for the empty rows filled with a default value. @@ -431,8 +431,8 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase): def test_hashed_embedding_multi_dimension(self): with self.test_session(): embedding_weights = self._random_weights() - values = constant_op.constant( - [["foo", "bar", "bar"], ["bar", "bar", "foo"]]) + values = constant_op.constant([["foo", "bar", "bar"], + ["bar", "bar", "foo"]]) sampled_candidates = constant_op.constant( [[[1, 3, 4, 6], [1, 7, 8, 9], [1, 7, 8, 9]], [[1, 7, 8, 9], [1, 7, 8, 9], [1, 3, 4, 6]]]) @@ -489,8 +489,8 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase): result = embedding_ops._sampled_scattered_embedding_lookup_sparse( params, sp_values, dimension=5, hash_key=self._hash_key) - self.assertAllClose(result.eval(), [[0., 0., 0., 0., 0.], - [.3, .2, .2, .3, .1], + self.assertAllClose(result.eval(), [[0., 0., 0., 0., + 0.], [.3, .2, .2, .3, .1], [0., 0., 0., 0., 0.]]) def test_output_values_with_sampled_candidates(self): @@ -563,5 +563,224 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase): self.assertAllClose(result.eval(), result_abc.eval()) +def _PName(param_id): + return "p" + str(param_id) + + +def _EmbeddingParams(num_shards, + vocab_size, + dtype=dtypes.float32, + shape=None, + use_shapeless_placeholder=False): + p = [] + params = {} + feed_dict = {} + if not shape: + shape = [10] + for i in range(num_shards): + shard_shape = [vocab_size // num_shards] + shape + if i < vocab_size % num_shards: # Excess goes evenly on the first shards + shard_shape[0] += 1 + + param_name = _PName(i) + + if use_shapeless_placeholder: + param = array_ops.placeholder(dtype, shape=None, name=param_name) + else: + param = constant_op.constant( + 1.0, shape=shard_shape, dtype=dtype, name=param_name) + p.append(param) + np_type = "f" if dtype == dtypes.float32 else "d" + val = (np.random.rand(*shard_shape).astype(np_type)) + 1 + params[param_name + ":0"] = val + feed_dict[param.name] = val + return p, params, feed_dict + + +def _EmbeddingResult(params, + id_vals, + num_shards, + vocab_size, + partition_strategy="mod", + weight_vals=None): + if weight_vals is None: + weight_vals = np.copy(id_vals) + weight_vals.fill(1) + values = [] + weights = [] + weights_squared = [] + for ids, wts in zip(id_vals, weight_vals): + value_aggregation = None + weight_aggregation = None + squared_weight_aggregation = None + if isinstance(ids, compat.integral_types): + ids = [ids] + wts = [wts] + for i, weight_value in zip(ids, wts): + if partition_strategy == "mod": + val = np.copy(params[_PName(i % num_shards) + ":0"][ + i // num_shards, :]) * weight_value + elif partition_strategy == "div": + ids_per_partition, extras = divmod(vocab_size, num_shards) + threshold = extras * (ids_per_partition + 1) + if i < threshold: + partition = i // (ids_per_partition + 1) + offset = i % (ids_per_partition + 1) + else: + partition = extras + (i - threshold) // ids_per_partition + offset = (i - threshold) % ids_per_partition + val = np.copy( + params[_PName(partition) + ":0"][offset, :]) * weight_value + else: + assert False + if value_aggregation is None: + assert weight_aggregation is None + assert squared_weight_aggregation is None + value_aggregation = val + weight_aggregation = weight_value + squared_weight_aggregation = weight_value * weight_value + else: + assert weight_aggregation is not None + assert squared_weight_aggregation is not None + value_aggregation += val + weight_aggregation += weight_value + squared_weight_aggregation += weight_value * weight_value + values.append(value_aggregation) + weights.append(weight_aggregation) + weights_squared.append(squared_weight_aggregation) + values = np.array(values).astype(np.float32) + weights = np.array(weights).astype(np.float32) + weights_squared = np.array(weights_squared).astype(np.float32) + return values, weights, weights_squared + + +class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase): + + def _RandomIdsAndWeights(self, batch_size, vocab_size): + max_val_per_entry = 6 + vals_per_batch_entry = np.random.randint( + 1, max_val_per_entry, size=batch_size) + num_vals = np.sum(vals_per_batch_entry) + + ids = np.random.randint(vocab_size, size=num_vals) + weights = 1 + np.random.rand(num_vals) + + indices = [] + for batch_entry, num_val in enumerate(vals_per_batch_entry): + for val_index in range(num_val): + indices.append([batch_entry, val_index]) + + shape = [batch_size, max_val_per_entry] + + sp_ids = sparse_tensor_lib.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(ids, dtypes.int32), + constant_op.constant(shape, dtypes.int64)) + sp_weights = sparse_tensor_lib.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(weights, dtypes.float32), + constant_op.constant(shape, dtypes.int64)) + + return sp_ids, sp_weights, ids, weights, vals_per_batch_entry + + def _GroupByBatchEntry(self, vals, vals_per_batch_entry): + grouped_vals = [] + index = 0 + for num_val in vals_per_batch_entry: + grouped_vals.append(list(vals[index:(index + num_val)])) + index += num_val + return grouped_vals + + def testEmbeddingLookupSparse(self): + vocab_size = 13 + batch_size = 10 + param_shape = [2, 5] + expected_lookup_result_shape = [None] + param_shape + + sp_ids, sp_weights, ids, weights, vals_per_batch_entry = ( + self._RandomIdsAndWeights(batch_size, vocab_size)) + + grouped_ids = self._GroupByBatchEntry(ids, vals_per_batch_entry) + grouped_weights = self._GroupByBatchEntry(weights, vals_per_batch_entry) + grouped_ignored_weights = self._GroupByBatchEntry( + np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry) + + for num_shards, combiner, dtype, ignore_weights in itertools.product( + [1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32, + dtypes.float64], [True, False]): + + with self.test_session(): + p, params, feed_dict = _EmbeddingParams( + num_shards, vocab_size, shape=param_shape, dtype=dtype) + embedding_sum = \ + embedding_ops.embedding_lookup_sparse_with_distributed_aggregation( + p, + sp_ids, + None if ignore_weights else sp_weights, + combiner=combiner) + + self.assertEqual(embedding_sum.get_shape().as_list(), + expected_lookup_result_shape) + + tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict) + + np_embedding_sum, np_weight_sum, np_weight_sq_sum = _EmbeddingResult( + params, + grouped_ids, + num_shards, + vocab_size, + weight_vals=grouped_ignored_weights + if ignore_weights else grouped_weights) + if combiner == "mean": + np_embedding_sum /= np.reshape(np_weight_sum, (batch_size, 1, 1)) + if combiner == "sqrtn": + np_embedding_sum /= np.reshape( + np.sqrt(np_weight_sq_sum), (batch_size, 1, 1)) + self.assertAllClose(np_embedding_sum, tf_embedding_sum) + + def testGradientsEmbeddingLookupSparse(self): + vocab_size = 12 + batch_size = 4 + param_shape = [2, 3] + sp_ids, sp_weights, _, _, _ = (self._RandomIdsAndWeights( + batch_size, vocab_size)) + + for num_shards, combiner, dtype, ignore_weights in itertools.product( + [1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32, + dtypes.float64], [True, False]): + with self.test_session(): + x, params, _ = _EmbeddingParams( + num_shards, vocab_size, shape=param_shape, dtype=dtype) + + y = embedding_ops.embedding_lookup_sparse_with_distributed_aggregation( + x, + sp_ids, + None if ignore_weights else sp_weights, + combiner=combiner) + x_name = [_PName(i) for i in range(num_shards)] + x_init_value = [params[x_n + ":0"] for x_n in x_name] + x_shape = [i.shape for i in x_init_value] + y_shape = [batch_size] + list(params[_PName(0) + ":0"].shape[1:]) + err = gradient_checker.compute_gradient_error( + x, x_shape, y, y_shape, x_init_value=x_init_value) + self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3) + + def testIncompatibleShapes(self): + with self.test_session(): + x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32) + sp_ids = sparse_tensor_lib.SparseTensor( + constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64), + constant_op.constant([0, 1, 2], dtypes.int32), + constant_op.constant([2, 2], dtypes.int64)) + sp_weights = sparse_tensor_lib.SparseTensor( + constant_op.constant([[0, 0], [0, 1]], dtypes.int64), + constant_op.constant([12.0, 5.0], dtypes.float32), + constant_op.constant([1, 2], dtypes.int64)) + + with self.assertRaises(ValueError): + embedding_ops.embedding_lookup_sparse_with_distributed_aggregation( + x, sp_ids, sp_weights, combiner="mean") + + if __name__ == "__main__": test.main() |