aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/listdiff_op_test.py
blob: b4607be1fb27c5a7ca3ad2033a57a587bc5cbeb0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Tests for tensorflow.kernels.listdiff_op."""

import tensorflow.python.platform

import numpy as np
import tensorflow as tf


class ListDiffTest(tf.test.TestCase):
  """Tests for `tf.listdiff(x, y)`.

  Each case checks both outputs of the op:
    * `out`: the entries of `x` that do not occur in `y`, kept in their
      original order with duplicates preserved.
    * `idx`: the positions in `x` of those entries.
  """

  # Element dtypes that every value-agnostic case is run against. This replaces
  # the former `np.float`, which was just the builtin `float` (float64). It
  # duplicated `np.double` and was removed in NumPy 1.24. `np.float32` is used
  # instead so single-precision inputs are exercised as well.
  _DTYPES = (np.int32, np.int64, np.float32, np.double)

  def _testListDiff(self, x, y, out, idx, dtype=np.int32):
    """Runs `tf.listdiff(x, y)` and compares both outputs to expectations.

    Args:
      x: 1-D sequence of values to filter.
      y: 1-D sequence of values to remove from `x`.
      out: expected entries of `x` not present in `y`, in order.
      idx: expected positions in `x` of the entries of `out`.
      dtype: numpy dtype used for the values `x`, `y` and `out`.
    """
    x = np.array(x, dtype=dtype)
    y = np.array(y, dtype=dtype)
    out = np.array(out, dtype=dtype)
    # `idx` holds positions into `x`, not values. It therefore stays an
    # integer array whatever `dtype` the values under test use.
    idx = np.array(idx, dtype=np.int32)

    with self.test_session() as sess:
      x_tensor = tf.convert_to_tensor(x)
      y_tensor = tf.convert_to_tensor(y)
      out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor)
      tf_out, tf_idx = sess.run([out_tensor, idx_tensor])

    self.assertAllEqual(tf_out, out)
    self.assertAllEqual(tf_idx, idx)
    # Both outputs must be statically known to be vectors.
    self.assertEqual(1, out_tensor.get_shape().ndims)
    self.assertEqual(1, idx_tensor.get_shape().ndims)

  def testBasic1(self):
    x = [1, 2, 3, 4]
    y = [1, 2]
    out = [3, 4]
    idx = [2, 3]
    for t in self._DTYPES:
      self._testListDiff(x, y, out, idx, dtype=t)

  def testBasic2(self):
    x = [1, 2, 3, 4]
    y = [2]
    out = [1, 3, 4]
    idx = [0, 2, 3]
    for t in self._DTYPES:
      self._testListDiff(x, y, out, idx, dtype=t)

  def testBasic3(self):
    # Removed entries are not at the front, so the kept order must follow `x`.
    x = [1, 4, 3, 2]
    y = [4, 2]
    out = [1, 3]
    idx = [0, 2]
    for t in self._DTYPES:
      self._testListDiff(x, y, out, idx, dtype=t)

  def testDuplicates(self):
    # Repeated entries of `x` that are absent from `y` are all kept.
    x = [1, 2, 4, 3, 2, 3, 3, 1]
    y = [4, 2]
    out = [1, 3, 3, 3, 1]
    idx = [0, 3, 5, 6, 7]
    for t in self._DTYPES:
      self._testListDiff(x, y, out, idx, dtype=t)

  def testRandom(self):
    num_random_tests = 10
    int_low = -7
    int_high = 8
    max_size = 50
    # `range` works on both Python 2 and 3; `xrange` does not exist on 3.
    for _ in range(num_random_tests):
      x_size = np.random.randint(max_size + 1)
      x = np.random.randint(int_low, int_high, size=x_size)
      y_size = np.random.randint(max_size + 1)
      y = np.random.randint(int_low, int_high, size=y_size)
      # Reference result: keep every entry of `x` (with its position) that
      # does not occur anywhere in `y`. Build the lookup set once instead of
      # scanning `y` for each entry. This also avoids subscripting a `map`
      # object, which is an iterator on Python 3.
      y_values = set(y.tolist())
      out = []
      idx = []
      for pos, entry in enumerate(x.tolist()):
        if entry not in y_values:
          out.append(entry)
          idx.append(pos)
      for t in self._DTYPES:
        self._testListDiff(x, y, out, idx, dtype=t)

  def testInt32FullyOverlapping(self):
    x = [1, 2, 3, 4]
    y = [1, 2, 3, 4]
    out = []
    idx = []
    self._testListDiff(x, y, out, idx)

  def testInt32NonOverlapping(self):
    x = [1, 2, 3, 4]
    y = [5, 6]
    out = x
    idx = list(range(len(x)))
    self._testListDiff(x, y, out, idx)

  def testInt32EmptyX(self):
    x = []
    y = [1, 2]
    out = []
    idx = []
    self._testListDiff(x, y, out, idx)

  def testInt32EmptyY(self):
    x = [1, 2, 3, 4]
    y = []
    out = x
    idx = list(range(len(x)))
    self._testListDiff(x, y, out, idx)

  def testInt32EmptyXY(self):
    x = []
    y = []
    out = []
    idx = []
    self._testListDiff(x, y, out, idx)

if __name__ == "__main__":
  # Script entry point: run this module's test cases via `tf.test.main()`.
  tf.test.main()