aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py
blob: a36b8587d509d5c771b96777fedc75f6c91810e1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""Tests for CandidateSamplerOp."""
import tensorflow.python.platform

import numpy as np
import tensorflow as tf


class RangeSamplerOpsTest(tf.test.TestCase):
  """Tests for the candidate sampler ops exposed under `tf.nn`.

  Most tests use `tf.nn.all_candidate_sampler`, which (as asserted in
  `testSampledCandidates`) returns every class id in `[0, RANGE)` in order,
  so their expected outputs are fully deterministic. `testSeed` exercises the
  seed handling of `tf.nn.log_uniform_candidate_sampler`.
  """

  BATCH_SIZE = 3
  NUM_TRUE = 2
  RANGE = 5
  # all_candidate_sampler is expected to return the whole range (see
  # testSampledCandidates), so the number of sampled ids equals RANGE.
  NUM_SAMPLED = RANGE

  # True class ids per batch row: BATCH_SIZE rows of NUM_TRUE labels each.
  TRUE_LABELS = [[1, 2], [0, 4], [3, 3]]

  def _true_classes(self):
    """Returns TRUE_LABELS as an int64 constant in the current default graph."""
    return tf.constant(self.TRUE_LABELS, dtype=tf.int64)

  def testTrueCandidates(self):
    """Sanity-checks the fixture: flat labels reshape back to TRUE_LABELS."""
    with self.test_session() as sess:
      # Row index of each entry in the flattened label vector below.
      indices = tf.constant([0, 0, 1, 1, 2, 2])
      true_candidates_vec = tf.constant([1, 2, 0, 4, 3, 3])
      true_candidates_matrix = tf.reshape(
          true_candidates_vec, [self.BATCH_SIZE, self.NUM_TRUE])
      indices_val, true_candidates_val = sess.run(
          [indices, true_candidates_matrix])

    self.assertAllEqual(indices_val, [0, 0, 1, 1, 2, 2])
    self.assertAllEqual(true_candidates_val, self.TRUE_LABELS)

  def testSampledCandidates(self):
    """all_candidate_sampler returns every class id in [0, RANGE), in order."""
    with self.test_session():
      sampled_candidates, _, _ = tf.nn.all_candidate_sampler(
          self._true_classes(), self.NUM_TRUE, self.NUM_SAMPLED, unique=True)
      result = sampled_candidates.eval()

    self.assertAllEqual(result, list(range(self.RANGE)))
    self.assertEqual(sampled_candidates.get_shape(), [self.NUM_SAMPLED])

  def testTrueLogExpectedCount(self):
    """Each true class has expected count 1 (log 0) when all are sampled."""
    with self.test_session():
      _, true_expected_count, _ = tf.nn.all_candidate_sampler(
          self._true_classes(), self.NUM_TRUE, self.NUM_SAMPLED, unique=True)
      true_log_expected_count = tf.log(true_expected_count)
      result = true_log_expected_count.eval()

    self.assertAllEqual(result, [[0.0] * self.NUM_TRUE] * self.BATCH_SIZE)
    self.assertEqual(true_expected_count.get_shape(),
                     [self.BATCH_SIZE, self.NUM_TRUE])
    self.assertEqual(true_log_expected_count.get_shape(),
                     [self.BATCH_SIZE, self.NUM_TRUE])

  def testSampledLogExpectedCount(self):
    """Each sampled class has expected count 1 (log 0) when all are sampled."""
    with self.test_session():
      _, _, sampled_expected_count = tf.nn.all_candidate_sampler(
          self._true_classes(), self.NUM_TRUE, self.NUM_SAMPLED, unique=True)
      sampled_log_expected_count = tf.log(sampled_expected_count)
      result = sampled_log_expected_count.eval()

    self.assertAllEqual(result, [0.0] * self.NUM_SAMPLED)
    self.assertEqual(sampled_expected_count.get_shape(), [self.NUM_SAMPLED])
    self.assertEqual(sampled_log_expected_count.get_shape(),
                     [self.NUM_SAMPLED])

  def testAccidentalHits(self):
    """Every true label is an accidental hit when the whole range is sampled."""
    with self.test_session() as sess:
      true_classes = self._true_classes()
      sampled_candidates, _, _ = tf.nn.all_candidate_sampler(
          true_classes, self.NUM_TRUE, self.NUM_SAMPLED, unique=True)
      accidental_hits = tf.nn.compute_accidental_hits(
          true_classes, sampled_candidates, self.NUM_TRUE)
      indices, ids, weights = sess.run(accidental_hits)

    for hit_tensor in accidental_hits:
      self.assertEqual(1, hit_tensor.get_shape().ndims)
    # `ids` index into `sampled_candidates`. That tensor is [0, RANGE) in
    # order (see testSampledCandidates), so each position equals its class id.
    for index, id_, weight in zip(indices, ids, weights):
      self.assertIn(id_, self.TRUE_LABELS[index])
      self.assertLess(weight, -1.0e37)
    # The loop above passes vacuously on an empty result, so also require
    # that every (row, true label) pair is reported.
    expected_hits = {(row, label)
                     for row, labels in enumerate(self.TRUE_LABELS)
                     for label in labels}
    self.assertEqual(expected_hits,
                     set(zip(indices.tolist(), ids.tolist())))

  def testSeed(self):
    """A fixed seed gives repeatable draws; no seed gives varying draws."""

    def draw(seed):
      with self.test_session():
        sampled, _, _ = tf.nn.log_uniform_candidate_sampler(
            self._true_classes(),
            self.NUM_TRUE,
            self.NUM_SAMPLED,
            unique=True,
            range_max=self.RANGE,
            seed=seed)
        return sampled.eval()

    # An explicit non-zero seed must make the draw repeatable.
    for seed in [1, 12, 123, 1234]:
      self.assertAllEqual(draw(seed), draw(seed))
    # With seed=None the draws are not expected to repeat.
    num_same = sum(np.allclose(draw(None), draw(None)) for _ in range(10))
    # Tolerate the rare case where two unseeded draws happen to coincide.
    self.assertLessEqual(num_same, 2)


if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's test entry point.
  tf.test.main()