aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/listdiff_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/listdiff_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/listdiff_op_test.py56
1 files changed, 27 insertions, 29 deletions
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py
index 14d657c805..040d9eb3bd 100644
--- a/tensorflow/python/kernel_tests/listdiff_op_test.py
+++ b/tensorflow/python/kernel_tests/listdiff_op_test.py
@@ -10,57 +10,56 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
+_TYPES = [tf.int32, tf.int64, tf.float32, tf.float64, tf.string]
+
class ListDiffTest(tf.test.TestCase):
- def _testListDiff(self, x, y, out, idx, dtype=np.int32):
- x = np.array(x, dtype=dtype)
- y = np.array(y, dtype=dtype)
- out = np.array(out, dtype=dtype)
- idx = np.array(idx, dtype=dtype)
+ def _testListDiff(self, x, y, out, idx):
+ for dtype in _TYPES:
+ if dtype == tf.string:
+ x = [str(a) for a in x]
+ y = [str(a) for a in y]
+ out = [str(a) for a in out]
- with self.test_session() as sess:
- x_tensor = tf.convert_to_tensor(x)
- y_tensor = tf.convert_to_tensor(y)
- out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor)
- tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+ with self.test_session() as sess:
+ x_tensor = tf.convert_to_tensor(x, dtype=dtype)
+ y_tensor = tf.convert_to_tensor(y, dtype=dtype)
+ out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor)
+ tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
- self.assertAllEqual(tf_out, out)
- self.assertAllEqual(tf_idx, idx)
- self.assertEqual(1, out_tensor.get_shape().ndims)
- self.assertEqual(1, idx_tensor.get_shape().ndims)
+ self.assertAllEqual(tf_out, out)
+ self.assertAllEqual(tf_idx, idx)
+ self.assertEqual(1, out_tensor.get_shape().ndims)
+ self.assertEqual(1, idx_tensor.get_shape().ndims)
def testBasic1(self):
x = [1, 2, 3, 4]
y = [1, 2]
out = [3, 4]
idx = [2, 3]
- for t in [np.int32, np.int64, np.float, np.double]:
- self._testListDiff(x, y, out, idx, dtype=t)
+ self._testListDiff(x, y, out, idx)
def testBasic2(self):
x = [1, 2, 3, 4]
y = [2]
out = [1, 3, 4]
idx = [0, 2, 3]
- for t in [np.int32, np.int64, np.float, np.double]:
- self._testListDiff(x, y, out, idx, dtype=t)
+ self._testListDiff(x, y, out, idx)
def testBasic3(self):
x = [1, 4, 3, 2]
y = [4, 2]
out = [1, 3]
idx = [0, 2]
- for t in [np.int32, np.int64, np.float, np.double]:
- self._testListDiff(x, y, out, idx, dtype=t)
+ self._testListDiff(x, y, out, idx)
def testDuplicates(self):
x = [1, 2, 4, 3, 2, 3, 3, 1]
y = [4, 2]
out = [1, 3, 3, 3, 1]
idx = [0, 3, 5, 6, 7]
- for t in [np.int32, np.int64, np.float, np.double]:
- self._testListDiff(x, y, out, idx, dtype=t)
+ self._testListDiff(x, y, out, idx)
def testRandom(self):
num_random_tests = 10
@@ -78,38 +77,37 @@ class ListDiffTest(tf.test.TestCase):
else:
out = []
idx = []
- for t in [np.int32, np.int64, np.float, np.double]:
- self._testListDiff(x, y, out, idx, dtype=t)
+ self._testListDiff(list(x), list(y), out, idx)
- def testInt32FullyOverlapping(self):
+ def testFullyOverlapping(self):
x = [1, 2, 3, 4]
y = [1, 2, 3, 4]
out = []
idx = []
self._testListDiff(x, y, out, idx)
- def testInt32NonOverlapping(self):
+ def testNonOverlapping(self):
x = [1, 2, 3, 4]
y = [5, 6]
out = x
idx = np.arange(len(x))
self._testListDiff(x, y, out, idx)
- def testInt32EmptyX(self):
+ def testEmptyX(self):
x = []
y = [1, 2]
out = []
idx = []
self._testListDiff(x, y, out, idx)
- def testInt32EmptyY(self):
+ def testEmptyY(self):
x = [1, 2, 3, 4]
y = []
out = x
idx = np.arange(len(x))
self._testListDiff(x, y, out, idx)
- def testInt32EmptyXY(self):
+ def testEmptyXY(self):
x = []
y = []
out = []