aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
blob: bf2514498202e9227c2d74c036c7eecba5ccdf2c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""embedding_ops tests."""

# pylint: disable=unused-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import math
import sys

import numpy as np

from tensorflow.contrib.layers.python.layers import embedding_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class SafeEmbeddingLookupSparseTest(test.TestCase):
  """Tests for `embedding_ops.safe_embedding_lookup_sparse`.

  The fixtures mix valid ids, invalid (negative) ids, empty rows and
  non-positive weights. The expected values below check that invalid ids and
  non-positive weights are pruned, and that rows left without ids come back as
  a zero vector (or as the `default_id` embedding when one is given).
  """

  def _random_weights(self, vocab_size=4, embed_dim=4, num_shards=1):
    """Returns an evaluated random embedding table split into shards.

    The `[vocab_size, embed_dim]` table is created as `num_shards` row-wise
    partitioned variables, initialized in the current session, and returned
    as a list of numpy arrays (one per shard).
    """
    assert vocab_size > 0
    assert embed_dim > 0
    assert num_shards > 0
    assert num_shards <= vocab_size

    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)
    shards = partitioned_variables.create_partitioned_variables(
        shape=[vocab_size, embed_dim],
        slicing=[num_shards, 1],
        initializer=initializer)
    for shard in shards:
      shard.initializer.run()
    return [shard.eval() for shard in shards]

  def _sparse_ids_and_weights(self, indices, ids, weights, shape):
    """Builds `(sparse_ids, sparse_weights)` sharing `indices` and `shape`.

    Ids are int64 and weights are float32.
    """

    def _sparse(values, values_dtype):
      return sparse_tensor_lib.SparseTensor(
          constant_op.constant(indices, dtypes.int64),
          constant_op.constant(values, values_dtype),
          constant_op.constant(shape, dtypes.int64))

    return _sparse(ids, dtypes.int64), _sparse(weights, dtypes.float32)

  def _ids_and_weights_2d(self):
    """Returns 2-D `(sparse_ids, sparse_weights)`, one edge case per row.

    Row 0: multiple valid ids, 1 invalid id, weighted mean.
    Row 1: all ids are invalid (leaving no valid ids after pruning).
    Row 2: no ids to begin with.
    Row 3: single id.
    Row 4: all ids have <=0 weight.
    """
    return self._sparse_ids_and_weights(
        indices=[[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [4, 0], [4, 1]],
        ids=[0, 1, -1, -1, 2, 0, 1],
        weights=[1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5],
        shape=[5, 4])

  def _ids_and_weights_3d(self):
    """Returns 3-D `(sparse_ids, sparse_weights)`, one case per 2-D index.

    Index 0, 0: multiple valid ids, 1 invalid id, weighted mean.
    Index 0, 1: all ids are invalid (leaving no valid ids after pruning).
    Index 0, 2: no ids to begin with.
    Index 1, 0: single id.
    Index 1, 1: all ids have <=0 weight.
    Index 1, 2: no ids to begin with.
    """
    return self._sparse_ids_and_weights(
        indices=[[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0], [1, 0, 0],
                 [1, 1, 0], [1, 1, 1]],
        ids=[0, 1, -1, -1, 2, 0, 1],
        weights=[1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5],
        shape=[2, 3, 4])

  def test_safe_embedding_lookup_sparse_return_zero_vector(self):
    with self.test_session():
      embedding_weights = self._random_weights()
      sparse_ids, sparse_weights = self._ids_and_weights_2d()

      lookup = embedding_ops.safe_embedding_lookup_sparse(
          embedding_weights, sparse_ids, sparse_weights)
      actual = lookup.eval()

      table = embedding_weights[0]
      zeros = [0] * 4
      weighted_mean = (1.0 * table[0] + 2.0 * table[1]) / 3.0
      expected = [weighted_mean, zeros, zeros, table[2], zeros]
      self.assertAllClose(actual, expected)

  def test_safe_embedding_lookup_sparse_return_special_vector(self):
    with self.test_session():
      embedding_weights = self._random_weights()
      sparse_ids, sparse_weights = self._ids_and_weights_2d()

      lookup = embedding_ops.safe_embedding_lookup_sparse(
          embedding_weights, sparse_ids, sparse_weights, default_id=3)
      actual = lookup.eval()

      table = embedding_weights[0]
      weighted_mean = (1.0 * table[0] + 2.0 * table[1]) / 3.0
      # Rows with nothing left after pruning fall back to id 3's embedding.
      default = table[3]
      expected = [weighted_mean, default, default, table[2], default]
      self.assertAllClose(actual, expected)

  def test_safe_embedding_lookup_sparse_no_weights(self):
    with self.test_session():
      embedding_weights = self._random_weights()
      sparse_ids, _ = self._ids_and_weights_2d()

      lookup = embedding_ops.safe_embedding_lookup_sparse(
          embedding_weights, sparse_ids, None)
      actual = lookup.eval()

      table = embedding_weights[0]
      zeros = [0] * 4
      # Without weights, row 4's ids are kept and averaged.
      mean = (table[0] + table[1]) / 2.0
      expected = [mean, zeros, zeros, table[2], mean]
      self.assertAllClose(actual, expected)

  def test_safe_embedding_lookup_sparse_partitioned(self):
    with self.test_session():
      embedding_weights = self._random_weights(num_shards=3)
      sparse_ids, _ = self._ids_and_weights_2d()

      lookup = embedding_ops.safe_embedding_lookup_sparse(
          embedding_weights, sparse_ids, None)
      actual = lookup.eval()

      # Concatenate the shards' rows back into a single table.
      rows = list(itertools.chain.from_iterable(embedding_weights))
      zeros = [0] * 4
      mean = (rows[0] + rows[1]) / 2.0
      expected = [mean, zeros, zeros, rows[2], mean]
      self.assertAllClose(actual, expected)

  def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self):
    with self.test_session():
      embedding_weights = self._random_weights(num_shards=3)
      sparse_ids, sparse_weights = self._ids_and_weights_2d()

      # One shard with a different dtype than the others must be rejected.
      embedding_weights[1] = embedding_weights[1].astype(np.float64)
      with self.assertRaises(ValueError):
        embedding_ops.safe_embedding_lookup_sparse(
            embedding_weights, sparse_ids)
      # float64 embeddings combined with float32 weights must be rejected too.
      float64_weights = [
          constant_op.constant(w, dtype=dtypes.float64)
          for w in embedding_weights
      ]
      with self.assertRaises(ValueError):
        embedding_ops.safe_embedding_lookup_sparse(
            float64_weights, sparse_ids, sparse_weights)

  def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
    with self.test_session():
      embedding_weights = self._random_weights()
      sparse_ids, sparse_weights = self._ids_and_weights_3d()

      lookup = embedding_ops.safe_embedding_lookup_sparse(
          embedding_weights, sparse_ids, sparse_weights)
      actual = lookup.eval()

      table = embedding_weights[0]
      zeros = [0] * 4
      weighted_mean = (1.0 * table[0] + 2.0 * table[1]) / 3.0
      expected = [[weighted_mean, zeros, zeros], [table[2], zeros, zeros]]
      self.assertAllClose(actual, expected)

  def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
    with self.test_session():
      embedding_weights = self._random_weights()
      sparse_ids, sparse_weights = self._ids_and_weights_3d()

      lookup = embedding_ops.safe_embedding_lookup_sparse(
          embedding_weights, sparse_ids, sparse_weights, default_id=3)
      actual = lookup.eval()

      table = embedding_weights[0]
      weighted_mean = (1.0 * table[0] + 2.0 * table[1]) / 3.0
      default = table[3]
      expected = [[weighted_mean, default, default],
                  [table[2], default, default]]
      self.assertAllClose(actual, expected)

  def test_safe_embedding_lookup_sparse_3d_no_weights(self):
    with self.test_session():
      embedding_weights = self._random_weights()
      sparse_ids, _ = self._ids_and_weights_3d()

      lookup = embedding_ops.safe_embedding_lookup_sparse(
          embedding_weights, sparse_ids, None)
      actual = lookup.eval()

      table = embedding_weights[0]
      zeros = [0] * 4
      mean = (table[0] + table[1]) / 2.0
      expected = [[mean, zeros, zeros], [table[2], mean, zeros]]
      self.assertAllClose(actual, expected)

  def test_safe_embedding_lookup_sparse_3d_partitioned(self):
    with self.test_session():
      embedding_weights = self._random_weights(num_shards=3)
      sparse_ids, _ = self._ids_and_weights_3d()

      lookup = embedding_ops.safe_embedding_lookup_sparse(
          embedding_weights, sparse_ids, None)
      actual = lookup.eval()

      rows = list(itertools.chain.from_iterable(embedding_weights))
      zeros = [0] * 4
      mean = (rows[0] + rows[1]) / 2.0
      expected = [[mean, zeros, zeros], [rows[2], mean, zeros]]
      self.assertAllClose(actual, expected)

  def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
      self):
    with self.test_session():
      embedding_weights = self._random_weights(num_shards=3)
      sparse_ids, sparse_weights = self._ids_and_weights_3d()

      embedding_weights[1] = embedding_weights[1].astype(np.float64)
      with self.assertRaises(ValueError):
        embedding_ops.safe_embedding_lookup_sparse(
            embedding_weights, sparse_ids)
      float64_weights = [
          constant_op.constant(w, dtype=dtypes.float64)
          for w in embedding_weights
      ]
      with self.assertRaises(ValueError):
        embedding_ops.safe_embedding_lookup_sparse(
            float64_weights, sparse_ids, sparse_weights)


class ScatteredEmbeddingLookupTest(test.TestCase):
  """Tests for `embedding_ops.scattered_embedding_lookup[_sparse]`.

  Also covers `embedding_ops.embedding_lookup_unique`.
  """

  def setUp(self):
    # Run test.TestCase's own per-test setup first; overriding setUp without
    # calling it silently skips that setup.
    super(ScatteredEmbeddingLookupTest, self).setUp()
    # Pin the graph-level seed so the random weights are deterministic.
    random_seed.set_random_seed(1)

  def _random_weights(self, size=50, num_shards=1):
    """Creates `num_shards` initialized variables holding `size` weights.

    The 1-D weight vector is drawn from a truncated normal (mean 0, stddev 1)
    and partitioned into `num_shards` variables, which are returned
    un-evaluated.
    """
    assert size > 0
    assert num_shards > 0
    assert num_shards <= size

    embedding_weights = partitioned_variables.create_partitioned_variables(
        shape=[size],
        slicing=[num_shards],
        initializer=init_ops.truncated_normal_initializer(
            mean=0.0, stddev=1.0, dtype=dtypes.float32))
    for w in embedding_weights:
      w.initializer.run()
    return embedding_weights

  def test_scattered_embedding_consistency(self):
    with self.test_session():
      embedding_weights = self._random_weights()
      values = constant_op.constant(["foo", "foo"])

      embedding_lookup_result = embedding_ops.scattered_embedding_lookup(
          embedding_weights, values, dimension=10).eval()

      self.assertAllEqual(embedding_lookup_result.shape, [2, 10])
      # Identical inputs must produce identical embeddings.
      self.assertAllEqual(embedding_lookup_result[0],
                          embedding_lookup_result[1])

  def test_scattered_embedding_multiple_partition(self):
    with self.test_session():
      embedding_weights = self._random_weights(num_shards=7)
      values = constant_op.constant([4, 4, 5])

      embedding_lookup_result = embedding_ops.scattered_embedding_lookup(
          embedding_weights, values, dimension=5).eval()

      self.assertAllEqual(embedding_lookup_result.shape, [3, 5])
      self.assertAllEqual(embedding_lookup_result[0],
                          embedding_lookup_result[1])
      # Different embedding expected for different value (every coordinate
      # is required to differ, hence np.min).
      embedding_diff = np.min(
          (embedding_lookup_result[2] - embedding_lookup_result[0])**2)
      self.assertGreater(embedding_diff, 0)

  def test_scattered_embedding_coverage(self):
    with self.test_session():
      size = 8
      embedding_weights = self._random_weights(size=size, num_shards=3)
      values = constant_op.constant(["foo"])

      # Large embedding dimension to cover the full range of weights.
      embedding_lookup_result = embedding_ops.scattered_embedding_lookup(
          embedding_weights, values, dimension=100).eval()

      self.assertEqual(len(np.unique(embedding_lookup_result[0])), size)

  def test_scattered_embedding_multi_dimension(self):
    with self.test_session():
      embedding_weights = self._random_weights()
      values = constant_op.constant([["foo", "bar", "bar"],
                                     ["bar", "bar", "foo"]])

      embedding_lookup_result = embedding_ops.scattered_embedding_lookup(
          embedding_weights, values, dimension=10).eval()

      self.assertAllEqual(embedding_lookup_result.shape, [2, 3, 10])
      # "foo" appears at [0][0] and [1][2].
      self.assertAllEqual(embedding_lookup_result[0][0],
                          embedding_lookup_result[1][2])

  def test_scattered_embedding_lookup_sparse(self):
    with self.test_session():
      embedding_weights = self._random_weights(num_shards=3)
      # Rows 2 and 4 have no values.
      sparse_tensor = sparse_tensor_lib.SparseTensor(
          values=["foo", "bar", "foo", "bar"],
          indices=[[0, 0], [1, 0], [1, 1], [3, 0]],
          dense_shape=[5, 2])

      embedding_lookup_result = (
          embedding_ops.scattered_embedding_lookup_sparse(
              embedding_weights, sparse_tensor, dimension=5,
              combiner="mean").eval())

      self.assertAllEqual(embedding_lookup_result.shape, [5, 5])
      # Same non-zero embedding for the empty rows filled with a default value.
      self.assertAllEqual(embedding_lookup_result[2],
                          embedding_lookup_result[4])
      embedding_norm = np.sum(embedding_lookup_result[2]**2)
      self.assertGreater(embedding_norm, 0)

      # Row 1 holds "foo" and "bar", so its mean combines rows 0 and 3.
      self.assertAllEqual(embedding_lookup_result[1], 0.5 * (
          embedding_lookup_result[0] + embedding_lookup_result[3]))

  def test_embedding_lookup_unique(self):
    d_embed = 5
    n_embed = 10
    idx_shape = (2, 3, 4)
    embeds = np.random.randn(n_embed, d_embed)
    idx = np.random.randint(0, n_embed, idx_shape)

    with self.test_session():
      # numpy fancy indexing serves as the reference result.
      embedded_np = embeds[idx]
      embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()

    self.assertEqual(embedded_np.shape, embedded_tf.shape)
    np.testing.assert_almost_equal(embedded_np, embedded_tf)

  def test_embedding_lookup_unique_param3d(self):
    embeds = np.random.randn(5, 3, 3)
    idx = np.random.randint(0, 5, 10)
    idx2d = np.random.randint(0, 5, (10, 2))

    with self.test_session():
      embedded_np = embeds[idx]
      embedded_np2d = embeds[idx2d]
      embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
      # Params may also be passed as a single-element list.
      embedded_tf_lst = embedding_ops.embedding_lookup_unique([embeds],
                                                              idx).eval()
      embedded_tf2d = embedding_ops.embedding_lookup_unique(embeds,
                                                            idx2d).eval()

    self.assertEqual(embedded_np.shape, embedded_tf.shape)
    np.testing.assert_almost_equal(embedded_np, embedded_tf)
    self.assertEqual(embedded_np.shape, embedded_tf_lst.shape)
    np.testing.assert_almost_equal(embedded_np, embedded_tf_lst)
    self.assertEqual(embedded_np2d.shape, embedded_tf2d.shape)
    np.testing.assert_almost_equal(embedded_np2d, embedded_tf2d)


class SampledScatteredEmbeddingLookupTest(test.TestCase):
  """Tests for `embedding_ops._sampled_scattered_embedding_lookup`."""

  def setUp(self):
    # Run test.TestCase's own per-test setup first; overriding setUp without
    # calling it silently skips that setup.
    super(SampledScatteredEmbeddingLookupTest, self).setUp()
    # Pin the graph-level seed so the random weights are deterministic.
    random_seed.set_random_seed(1)
    self._hash_key = 1

  def _random_weights(self, size=50, num_shards=1):
    """Creates `num_shards` initialized variables holding `size` weights.

    The 1-D weight vector is drawn from a truncated normal (mean 0, stddev 1)
    and partitioned into `num_shards` variables, which are returned
    un-evaluated.
    """
    assert size > 0
    assert num_shards > 0
    assert num_shards <= size

    embedding_weights = partitioned_variables.create_partitioned_variables(
        shape=[size],
        slicing=[num_shards],
        initializer=init_ops.truncated_normal_initializer(
            mean=0.0, stddev=1.0, dtype=dtypes.float32))
    for w in embedding_weights:
      w.initializer.run()
    return embedding_weights

  def test_hashed_embedding_consistency(self):
    with self.test_session():
      embedding_weights = self._random_weights()
      values = constant_op.constant(["foo", "foo"])
      # The first three sampled_candidates are equal, so the first three
      # embedding weights will be equal.
      sampled_candidates = constant_op.constant([[1, 3, 4, 6], [1, 3, 4, 7]])

      embedding_lookup_result = (  # pylint: disable=protected-access
          embedding_ops._sampled_scattered_embedding_lookup(
              embedding_weights,
              values,
              sampled_candidates=sampled_candidates,
              hash_key=self._hash_key).eval())

      self.assertAllEqual(embedding_lookup_result.shape, [2, 4])
      self.assertAllEqual(embedding_lookup_result[0][:3],
                          embedding_lookup_result[1][:3])
      # The last candidates differ (6 vs 7), so the last weights should too.
      self.assertNotEqual(embedding_lookup_result[0][3],
                          embedding_lookup_result[1][3])

  def test_hashed_embedding_multi_dimension(self):
    with self.test_session():
      embedding_weights = self._random_weights()
      values = constant_op.constant([["foo", "bar", "bar"],
                                     ["bar", "bar", "foo"]])
      sampled_candidates = constant_op.constant(
          [[[1, 3, 4, 6], [1, 7, 8, 9], [1, 7, 8, 9]],
           [[1, 7, 8, 9], [1, 7, 8, 9], [1, 3, 4, 6]]])

      embedding_lookup_result = (  # pylint: disable=protected-access
          embedding_ops._sampled_scattered_embedding_lookup(
              embedding_weights,
              values,
              sampled_candidates=sampled_candidates,
              hash_key=self._hash_key).eval())

      self.assertAllEqual(embedding_lookup_result.shape, [2, 3, 4])
      # Same value ("foo") and same candidates at [0][0] and [1][2].
      self.assertAllEqual(embedding_lookup_result[0][0],
                          embedding_lookup_result[1][2])

      # sampled_candidates whose leading shape differs from `values` must be
      # rejected at run time.
      invalid_indices = constant_op.constant([[[1, 3, 4, 6], [1, 7, 8, 9]],
                                              [[1, 7, 8, 9], [1, 7, 8, 9]]])
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, (
          r"\[The shape of sampled_candidates: \] \[2 2 4\] "
          r"\[ does not match the shape of values: \] \[2 3\]")):
        # pylint: disable=protected-access
        embedding_ops._sampled_scattered_embedding_lookup(
            embedding_weights, values,
            sampled_candidates=invalid_indices).eval()


class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
  """Tests for `embedding_ops._sampled_scattered_embedding_lookup_sparse`."""
  # pylint: disable=protected-access

  def setUp(self):
    # Run test.TestCase's own per-test setup first; overriding setUp without
    # calling it silently skips that setup.
    super(SampledScatteredEmbeddingLookupSparseTest, self).setUp()
    random_seed.set_random_seed(1)
    self._hash_key = 1

  def test_output_shape(self):
    """Verifies the shape of the output tensor."""
    with self.test_session():
      sp_values = sparse_tensor_lib.SparseTensor(
          values=["a", "a", "b", "c", "d", "e", "f"],
          indices=[[1, 0], [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5]],
          dense_shape=[3, 6])
      params = constant_op.constant([.1, .2, .3])

      result = embedding_ops._sampled_scattered_embedding_lookup_sparse(
          params, sp_values, dimension=4, hash_key=self._hash_key)

      self.assertEqual(result.eval().shape, (3, 4))

  def test_output_values(self):
    """Verifies the values in a trivial case."""
    with self.test_session():
      sp_values = sparse_tensor_lib.SparseTensor(
          values=["a"], indices=[[1, 0]], dense_shape=[3, 1])
      params = constant_op.constant([.1, .2, .3])

      result = embedding_ops._sampled_scattered_embedding_lookup_sparse(
          params, sp_values, dimension=5, hash_key=self._hash_key)

      # Only row 1 has a value; rows 0 and 2 come back as zeros.
      self.assertAllClose(result.eval(), [[0., 0., 0., 0.,
                                           0.], [.3, .2, .2, .3, .1],
                                          [0., 0., 0., 0., 0.]])

  def test_output_values_with_sampled_candidates(self):
    """Verifies the values for given sampled_candidates."""
    with self.test_session():
      sp_values = sparse_tensor_lib.SparseTensor(
          values=["a", "a", "b", "c", "d", "e", "f"],
          indices=[[1, 0], [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5]],
          dense_shape=[3, 6])
      params = constant_op.constant([.1, .2, .3])

      sampled_candidates = [[1, 0], [2, 1], [3, 2]]
      sampled_result = embedding_ops._sampled_scattered_embedding_lookup_sparse(
          params,
          sp_values,
          sampled_candidates=constant_op.constant(sampled_candidates),
          hash_key=self._hash_key)
      full_result = embedding_ops._sampled_scattered_embedding_lookup_sparse(
          params, sp_values, dimension=4, hash_key=self._hash_key)

      sampled_result_val = sampled_result.eval()
      full_result_val = full_result.eval()
      self.assertEqual(sampled_result_val.shape, (3, 2))
      # Each sampled row must equal the full result gathered at its candidates.
      for row, candidates in enumerate(sampled_candidates):
        self.assertAllClose(sampled_result_val[row],
                            full_result_val[row, candidates])

  def test_output_values_with_sign_hash(self):
    """Verifies the values in a trivial case with hash_signs=True."""
    with self.test_session():
      sp_values = sparse_tensor_lib.SparseTensor(
          values=["a"], indices=[[1, 0]], dense_shape=[3, 1])
      params = constant_op.constant([.1, .1, .1])

      result = embedding_ops._sampled_scattered_embedding_lookup_sparse(
          params,
          sp_values,
          dimension=4,
          with_sign_hash=True,
          hash_key=self._hash_key)

      self.assertAllClose(result.eval(), [[0., 0., 0., 0.], [-.1, -.1, -.1, .1],
                                          [0., 0., 0., 0.]])

  def test_distributive_property(self):
    """Verifies the distributive property of matrix multiplication."""
    with self.test_session():
      params = constant_op.constant([.1, .2, .3])
      sp_values_a = sparse_tensor_lib.SparseTensor(
          values=["a"], indices=[[0, 0]], dense_shape=[3, 1])
      sp_values_b = sparse_tensor_lib.SparseTensor(
          values=["b"], indices=[[2, 0]], dense_shape=[3, 1])
      sp_values_c = sparse_tensor_lib.SparseTensor(
          values=["c"], indices=[[2, 0]], dense_shape=[3, 1])
      sp_values = sparse_tensor_lib.SparseTensor(
          values=["a", "b", "c"],
          indices=[[0, 0], [2, 0], [2, 1]],
          dense_shape=[3, 2])

      result_a = embedding_ops._sampled_scattered_embedding_lookup_sparse(
          params, sp_values_a, dimension=4, hash_key=self._hash_key)
      result_b = embedding_ops._sampled_scattered_embedding_lookup_sparse(
          params, sp_values_b, dimension=4, hash_key=self._hash_key)
      result_c = embedding_ops._sampled_scattered_embedding_lookup_sparse(
          params, sp_values_c, dimension=4, hash_key=self._hash_key)
      result = embedding_ops._sampled_scattered_embedding_lookup_sparse(
          params, sp_values, dimension=4, hash_key=self._hash_key)

      # Looking up a+b+c at once must equal the sum of separate lookups.
      result_abc = math_ops.add_n([result_a, result_b, result_c])
      self.assertAllClose(result.eval(), result_abc.eval())


def _PName(param_id):
  return "p" + str(param_id)


def _EmbeddingParams(num_shards,
                     vocab_size,
                     dtype=dtypes.float32,
                     shape=None,
                     use_shapeless_placeholder=False):
  """Builds sharded embedding parameters plus matching random values.

  The vocabulary is split row-wise into `num_shards` shards; when `vocab_size`
  is not divisible by `num_shards`, the first `vocab_size % num_shards` shards
  get one extra row.

  Args:
    num_shards: Number of parameter shards to create.
    vocab_size: Total number of embedding rows across all shards.
    dtype: Dtype of the parameters (anything `dtypes.as_dtype` accepts).
    shape: Per-row embedding shape (list or tuple); defaults to `[10]` when
      falsy.
    use_shapeless_placeholder: If True, each shard is a placeholder with an
      unknown shape instead of a constant.

  Returns:
    A tuple `(p, params, feed_dict)`. `p` is the list of shard tensors.
    `params` maps `_PName(i) + ":0"` to shard `i`'s numpy values. `feed_dict`
    maps each shard tensor's actual name to those same values.
  """
  p = []
  params = {}
  feed_dict = {}
  if not shape:
    shape = [10]
  # Derive the numpy dtype from `dtype` so the generated values always match
  # the tensor dtype (a float32-or-else-float64 switch mis-typed e.g. float16).
  np_dtype = dtypes.as_dtype(dtype).as_numpy_dtype
  for i in range(num_shards):
    # list() lets callers pass `shape` as a tuple as well as a list.
    shard_shape = [vocab_size // num_shards] + list(shape)
    if i < vocab_size % num_shards:  # Excess goes evenly on the first shards
      shard_shape[0] += 1

    param_name = _PName(i)

    if use_shapeless_placeholder:
      param = array_ops.placeholder(dtype, shape=None, name=param_name)
    else:
      # Callers in this file supply `val` (via feed_dict or x_init_value) in
      # place of this constant's value.
      param = constant_op.constant(
          1.0, shape=shard_shape, dtype=dtype, name=param_name)
    p.append(param)
    # Values lie in [1, 2), presumably to keep every entry away from zero.
    val = np.random.rand(*shard_shape).astype(np_dtype) + 1
    params[param_name + ":0"] = val
    feed_dict[param.name] = val
  return p, params, feed_dict


def _EmbeddingResult(params,
                     id_vals,
                     num_shards,
                     vocab_size,
                     partition_strategy="mod",
                     weight_vals=None):
  """Computes reference (numpy) weighted embedding sums per batch entry.

  Args:
    params: Dict mapping `_PName(shard) + ":0"` to that shard's numpy values,
      as returned by `_EmbeddingParams`.
    id_vals: Sequence with one entry per batch element; each entry is either a
      single integer id or a (possibly ragged) sequence of ids.
    num_shards: Number of shards the vocabulary is split across.
    vocab_size: Total vocabulary size (used by the "div" strategy).
    partition_strategy: "mod" or "div", matching how `params` was sharded.
    weight_vals: Per-id weights with the same structure as `id_vals`.
      Defaults to a weight of 1 for every id.

  Returns:
    A tuple `(values, weights, weights_squared)` of float32 arrays holding,
    per batch entry, the weighted sum of embedding rows, the sum of weights,
    and the sum of squared weights. Results are cast to float32 regardless of
    the dtype of `params`.

  Raises:
    ValueError: If an id is looked up with an unknown `partition_strategy`.
  """
  if weight_vals is None:
    # Build the all-ones weights entry by entry so ragged `id_vals` work;
    # `np.copy(id_vals).fill(1)` breaks on ragged nested lists.
    weight_vals = [
        1 if isinstance(ids, compat.integral_types) else [1] * len(ids)
        for ids in id_vals
    ]

  def _lookup_row(i):
    """Returns a copy of the embedding row for id `i` under the strategy."""
    if partition_strategy == "mod":
      return np.copy(params[_PName(i % num_shards) + ":0"][i // num_shards, :])
    if partition_strategy == "div":
      ids_per_partition, extras = divmod(vocab_size, num_shards)
      # The first `extras` partitions each hold one extra id.
      threshold = extras * (ids_per_partition + 1)
      if i < threshold:
        partition, offset = divmod(i, ids_per_partition + 1)
      else:
        partition, offset = divmod(i - threshold, ids_per_partition)
        partition += extras
      return np.copy(params[_PName(partition) + ":0"][offset, :])
    raise ValueError("Unknown partition_strategy: %s" % partition_strategy)

  values = []
  weights = []
  weights_squared = []
  for ids, wts in zip(id_vals, weight_vals):
    if isinstance(ids, compat.integral_types):
      ids = [ids]
      wts = [wts]
    # The three aggregates stay None for an entry with no ids.
    value_aggregation = None
    weight_aggregation = None
    squared_weight_aggregation = None
    for i, weight_value in zip(ids, wts):
      # `_lookup_row` returns a fresh copy, so in-place `+=` below never
      # mutates `params`.
      val = _lookup_row(i) * weight_value
      if value_aggregation is None:
        value_aggregation = val
        weight_aggregation = weight_value
        squared_weight_aggregation = weight_value * weight_value
      else:
        value_aggregation += val
        weight_aggregation += weight_value
        squared_weight_aggregation += weight_value * weight_value
    values.append(value_aggregation)
    weights.append(weight_aggregation)
    weights_squared.append(squared_weight_aggregation)
  values = np.array(values).astype(np.float32)
  weights = np.array(weights).astype(np.float32)
  weights_squared = np.array(weights_squared).astype(np.float32)
  return values, weights, weights_squared


class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
  """Tests `embedding_lookup_sparse_with_distributed_aggregation`.

  Results are checked against the numpy reference in `_EmbeddingResult`.
  """

  def _RandomIdsAndWeights(self, batch_size, vocab_size):
    """Generates a random batch of sparse ids with matching weights.

    Each of the `batch_size` rows gets between 1 and 5 ids drawn from
    `[0, vocab_size)`, each paired with a weight in `[1, 2)`.

    Returns:
      `(sp_ids, sp_weights, ids, weights, vals_per_batch_entry)`: the two
      SparseTensors, the flat numpy ids and weights, and the per-row id count.
    """
    max_val_per_entry = 6
    # randint's upper bound is exclusive, so each row gets 1..5 values.
    vals_per_batch_entry = np.random.randint(
        1, max_val_per_entry, size=batch_size)
    num_vals = np.sum(vals_per_batch_entry)

    ids = np.random.randint(vocab_size, size=num_vals)
    weights = 1 + np.random.rand(num_vals)

    indices = [[row, col]
               for row, count in enumerate(vals_per_batch_entry)
               for col in range(count)]
    dense_shape = [batch_size, max_val_per_entry]

    def _sparse(values, values_dtype):
      return sparse_tensor_lib.SparseTensor(
          constant_op.constant(indices, dtypes.int64),
          constant_op.constant(values, values_dtype),
          constant_op.constant(dense_shape, dtypes.int64))

    sp_ids = _sparse(ids, dtypes.int32)
    sp_weights = _sparse(weights, dtypes.float32)
    return sp_ids, sp_weights, ids, weights, vals_per_batch_entry

  def _GroupByBatchEntry(self, vals, vals_per_batch_entry):
    """Splits flat `vals` into consecutive lists of the given lengths."""
    counts = np.asarray(vals_per_batch_entry)
    ends = np.cumsum(counts)
    starts = ends - counts
    return [list(vals[start:end]) for start, end in zip(starts, ends)]

  def testEmbeddingLookupSparse(self):
    vocab_size = 13
    batch_size = 10
    param_shape = [2, 5]
    expected_lookup_result_shape = [None] + param_shape

    sp_ids, sp_weights, ids, weights, vals_per_batch_entry = (
        self._RandomIdsAndWeights(batch_size, vocab_size))

    grouped_ids = self._GroupByBatchEntry(ids, vals_per_batch_entry)
    grouped_weights = self._GroupByBatchEntry(weights, vals_per_batch_entry)
    grouped_ignored_weights = self._GroupByBatchEntry(
        np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry)
    # Broadcasts per-row normalizers over the [batch, 2, 5] result.
    row_shape = (batch_size, 1, 1)

    configs = itertools.product([1, 5], ["sum", "mean", "sqrtn"],
                                [dtypes.float32, dtypes.float64],
                                [True, False])
    for num_shards, combiner, dtype, ignore_weights in configs:
      reference_weights = (
          grouped_ignored_weights if ignore_weights else grouped_weights)
      with self.test_session():
        p, params, feed_dict = _EmbeddingParams(
            num_shards, vocab_size, shape=param_shape, dtype=dtype)
        embedding_sum = (
            embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
                p,
                sp_ids,
                None if ignore_weights else sp_weights,
                combiner=combiner))

        self.assertEqual(embedding_sum.get_shape().as_list(),
                         expected_lookup_result_shape)

        tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)

        np_embedding_sum, np_weight_sum, np_weight_sq_sum = _EmbeddingResult(
            params, grouped_ids, num_shards, vocab_size,
            weight_vals=reference_weights)
        if combiner == "mean":
          np_embedding_sum /= np.reshape(np_weight_sum, row_shape)
        elif combiner == "sqrtn":
          np_embedding_sum /= np.reshape(np.sqrt(np_weight_sq_sum), row_shape)
        self.assertAllClose(np_embedding_sum, tf_embedding_sum)

  def testGradientsEmbeddingLookupSparse(self):
    vocab_size = 12
    batch_size = 4
    param_shape = [2, 3]
    sp_ids, sp_weights, _, _, _ = self._RandomIdsAndWeights(
        batch_size, vocab_size)

    configs = itertools.product([1, 3], ["sum", "mean", "sqrtn"],
                                [dtypes.float32, dtypes.float64],
                                [True, False])
    for num_shards, combiner, dtype, ignore_weights in configs:
      with self.test_session():
        x, params, _ = _EmbeddingParams(
            num_shards, vocab_size, shape=param_shape, dtype=dtype)

        y = embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
            x,
            sp_ids,
            None if ignore_weights else sp_weights,
            combiner=combiner)
        x_init_value = [params[_PName(i) + ":0"] for i in range(num_shards)]
        x_shape = [value.shape for value in x_init_value]
        y_shape = [batch_size] + list(x_init_value[0].shape[1:])
        err = gradient_checker.compute_gradient_error(
            x, x_shape, y, y_shape, x_init_value=x_init_value)
      tolerance = 1e-5 if dtype == dtypes.float64 else 2e-3
      self.assertLess(err, tolerance)

  def testIncompatibleShapes(self):
    with self.test_session():
      x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
      # sp_weights has fewer entries and a different dense shape than sp_ids.
      sp_ids = sparse_tensor_lib.SparseTensor(
          constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),
          constant_op.constant([0, 1, 2], dtypes.int32),
          constant_op.constant([2, 2], dtypes.int64))
      sp_weights = sparse_tensor_lib.SparseTensor(
          constant_op.constant([[0, 0], [0, 1]], dtypes.int64),
          constant_op.constant([12.0, 5.0], dtypes.float32),
          constant_op.constant([1, 2], dtypes.int64))

      with self.assertRaises(ValueError):
        embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
            x, sp_ids, sp_weights, combiner="mean")


if __name__ == "__main__":
  # Run every test case in this module through TensorFlow's test entry point.
  test.main()