1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
|
"""Tests for CandidateSamplerOp."""
import tensorflow.python.platform
import numpy as np
import tensorflow as tf
class RangeSamplerOpsTest(tf.test.TestCase):
  """Tests for the candidate-sampler ops exposed under `tf.nn`.

  Most tests drive `tf.nn.all_candidate_sampler` with NUM_SAMPLED equal to
  RANGE. The assertions expect it to return every id in [0, RANGE), with an
  expected count of exactly 1 (log expected count 0).
  """

  BATCH_SIZE = 3
  NUM_TRUE = 2
  RANGE = 5
  NUM_SAMPLED = RANGE
  # BATCH_SIZE rows of NUM_TRUE true labels each; every label is < RANGE.
  TRUE_LABELS = [[1, 2], [0, 4], [3, 3]]

  def _true_classes(self):
    """Return TRUE_LABELS as an int64 constant tensor.

    Call this inside a test's `test_session()` block, as the tests below do.
    """
    return tf.constant(self.TRUE_LABELS, dtype=tf.int64)

  def _sample_all(self, true_classes):
    """Run `tf.nn.all_candidate_sampler` on `true_classes` (unique=True).

    Returns:
      The (sampled_candidates, true_expected_count, sampled_expected_count)
      triple produced by the op.
    """
    return tf.nn.all_candidate_sampler(
        true_classes, self.NUM_TRUE, self.NUM_SAMPLED, unique=True)

  def testTrueCandidates(self):
    """A flat label vector reshaped to [BATCH_SIZE, NUM_TRUE] is TRUE_LABELS."""
    with self.test_session() as sess:
      indices = tf.constant([0, 0, 1, 1, 2, 2])
      true_candidates_vec = tf.constant([1, 2, 0, 4, 3, 3])
      true_candidates_matrix = tf.reshape(
          true_candidates_vec, [self.BATCH_SIZE, self.NUM_TRUE])
      indices_val, true_candidates_val = sess.run(
          [indices, true_candidates_matrix])
      self.assertAllEqual(indices_val, [0, 0, 1, 1, 2, 2])
      self.assertAllEqual(true_candidates_val, self.TRUE_LABELS)

  def testSampledCandidates(self):
    """all_candidate_sampler returns every id in [0, RANGE), in order."""
    with self.test_session():
      sampled_candidates, _, _ = self._sample_all(self._true_classes())
      result = sampled_candidates.eval()
      self.assertAllEqual(result, list(range(self.RANGE)))
      self.assertEqual(sampled_candidates.get_shape(), [self.NUM_SAMPLED])

  def testTrueLogExpectedCount(self):
    """Each true class has expected count 1, i.e. log expected count 0."""
    with self.test_session():
      _, true_expected_count, _ = self._sample_all(self._true_classes())
      true_log_expected_count = tf.log(true_expected_count)
      result = true_log_expected_count.eval()
      self.assertAllEqual(result, [[0.0] * self.NUM_TRUE] * self.BATCH_SIZE)
      expected_shape = [self.BATCH_SIZE, self.NUM_TRUE]
      self.assertEqual(true_expected_count.get_shape(), expected_shape)
      self.assertEqual(true_log_expected_count.get_shape(), expected_shape)

  def testSampledLogExpectedCount(self):
    """Each sampled id has expected count 1, i.e. log expected count 0."""
    with self.test_session():
      _, _, sampled_expected_count = self._sample_all(self._true_classes())
      sampled_log_expected_count = tf.log(sampled_expected_count)
      result = sampled_log_expected_count.eval()
      self.assertAllEqual(result, [0.0] * self.NUM_SAMPLED)
      self.assertEqual(sampled_expected_count.get_shape(), [self.NUM_SAMPLED])
      self.assertEqual(sampled_log_expected_count.get_shape(),
                       [self.NUM_SAMPLED])

  def testAccidentalHits(self):
    """Every reported accidental hit is a true label of its batch row."""
    with self.test_session() as sess:
      true_classes = self._true_classes()
      sampled_candidates, _, _ = self._sample_all(true_classes)
      accidental_hits = tf.nn.compute_accidental_hits(
          true_classes, sampled_candidates, self.NUM_TRUE)
      indices, ids, weights = sess.run(accidental_hits)
      # The (indices, ids, weights) outputs are each rank-1.
      for hit_tensor in accidental_hits:
        self.assertEqual(1, hit_tensor.get_shape().ndims)
      for index, id_, weight in zip(indices, ids, weights):
        self.assertIn(id_, self.TRUE_LABELS[index])
        # Hits are expected to carry a very large negative weight.
        self.assertLess(weight, -1.0e37)

  def testSeed(self):
    """A fixed seed makes sampling repeatable; seed=None does not."""

    def draw(seed):
      with self.test_session():
        sampled, _, _ = tf.nn.log_uniform_candidate_sampler(
            self._true_classes(),
            self.NUM_TRUE,
            self.NUM_SAMPLED,
            unique=True,
            range_max=self.RANGE,
            seed=seed)
        return sampled.eval()

    # A fixed, non-None seed must reproduce the same draw.
    for seed in [1, 12, 123, 1234]:
      self.assertAllEqual(draw(seed), draw(seed))

    # With seed=None a random seed is used, so two draws should usually differ.
    num_same = 0
    for _ in range(10):
      if np.array_equal(draw(None), draw(None)):
        num_same += 1
    # Two independent draws can still coincide (e.g. the same random seed is
    # picked twice), very rarely; tolerate a couple of matches.
    self.assertLessEqual(num_same, 2)
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test entry point.
  tf.test.main()
|