aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/listdiff_op_test.py
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2015-11-16 23:42:32 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2015-11-16 23:42:32 -0800
commit4213ac97be449d0e40631a314d2b7bd3901d4967 (patch)
treeb75b2fe8858068929e1bf0365f70cb14b80926ef /tensorflow/python/kernel_tests/listdiff_op_test.py
parent56313def004795f75ef8281a0294c958d28f1e06 (diff)
TensorFlow: conv improvements, label_image example, and
a few other changes. Changes: - Some improvements to convolution by using 32-bit indices by @benoitsteiner. Not all calls converted yet. Also some improvements to pooling as well by @benoitsteiner. - Improvements to sparse matmul CPU implementation by Ashish - Some fixes to warnings by @vrv - Doc fixes to padding by @Yangqing - Some improvements to Tensor wrappers by Eider - Speed up of matrix inverse on CPU by Rasmus - Add an example of doing image inference from a pre-trained model by @petewarden. - fixed formula in mnist example by nodir - Updates to event accumulator by Cassandra - Slight changes to tensor c api by @mrry - Handling of strings in listdiff by Phil - Fix negative fraction-of-queue-full stats by Frank - Type-checking improvement to importer by Yaroslav - logdir recursive search for Tensorboard by @danmane - Session.run() checks for empty graph by Manoj Base CL: 108013706
Diffstat (limited to 'tensorflow/python/kernel_tests/listdiff_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/listdiff_op_test.py56
1 files changed, 27 insertions, 29 deletions
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py
index 14d657c805..040d9eb3bd 100644
--- a/tensorflow/python/kernel_tests/listdiff_op_test.py
+++ b/tensorflow/python/kernel_tests/listdiff_op_test.py
@@ -10,57 +10,56 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
+_TYPES = [tf.int32, tf.int64, tf.float32, tf.float64, tf.string]
+
class ListDiffTest(tf.test.TestCase):
- def _testListDiff(self, x, y, out, idx, dtype=np.int32):
- x = np.array(x, dtype=dtype)
- y = np.array(y, dtype=dtype)
- out = np.array(out, dtype=dtype)
- idx = np.array(idx, dtype=dtype)
+ def _testListDiff(self, x, y, out, idx):
+ for dtype in _TYPES:
+ if dtype == tf.string:
+ x = [str(a) for a in x]
+ y = [str(a) for a in y]
+ out = [str(a) for a in out]
- with self.test_session() as sess:
- x_tensor = tf.convert_to_tensor(x)
- y_tensor = tf.convert_to_tensor(y)
- out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor)
- tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+ with self.test_session() as sess:
+ x_tensor = tf.convert_to_tensor(x, dtype=dtype)
+ y_tensor = tf.convert_to_tensor(y, dtype=dtype)
+ out_tensor, idx_tensor = tf.listdiff(x_tensor, y_tensor)
+ tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
- self.assertAllEqual(tf_out, out)
- self.assertAllEqual(tf_idx, idx)
- self.assertEqual(1, out_tensor.get_shape().ndims)
- self.assertEqual(1, idx_tensor.get_shape().ndims)
+ self.assertAllEqual(tf_out, out)
+ self.assertAllEqual(tf_idx, idx)
+ self.assertEqual(1, out_tensor.get_shape().ndims)
+ self.assertEqual(1, idx_tensor.get_shape().ndims)
def testBasic1(self):
x = [1, 2, 3, 4]
y = [1, 2]
out = [3, 4]
idx = [2, 3]
- for t in [np.int32, np.int64, np.float, np.double]:
- self._testListDiff(x, y, out, idx, dtype=t)
+ self._testListDiff(x, y, out, idx)
def testBasic2(self):
x = [1, 2, 3, 4]
y = [2]
out = [1, 3, 4]
idx = [0, 2, 3]
- for t in [np.int32, np.int64, np.float, np.double]:
- self._testListDiff(x, y, out, idx, dtype=t)
+ self._testListDiff(x, y, out, idx)
def testBasic3(self):
x = [1, 4, 3, 2]
y = [4, 2]
out = [1, 3]
idx = [0, 2]
- for t in [np.int32, np.int64, np.float, np.double]:
- self._testListDiff(x, y, out, idx, dtype=t)
+ self._testListDiff(x, y, out, idx)
def testDuplicates(self):
x = [1, 2, 4, 3, 2, 3, 3, 1]
y = [4, 2]
out = [1, 3, 3, 3, 1]
idx = [0, 3, 5, 6, 7]
- for t in [np.int32, np.int64, np.float, np.double]:
- self._testListDiff(x, y, out, idx, dtype=t)
+ self._testListDiff(x, y, out, idx)
def testRandom(self):
num_random_tests = 10
@@ -78,38 +77,37 @@ class ListDiffTest(tf.test.TestCase):
else:
out = []
idx = []
- for t in [np.int32, np.int64, np.float, np.double]:
- self._testListDiff(x, y, out, idx, dtype=t)
+ self._testListDiff(list(x), list(y), out, idx)
- def testInt32FullyOverlapping(self):
+ def testFullyOverlapping(self):
x = [1, 2, 3, 4]
y = [1, 2, 3, 4]
out = []
idx = []
self._testListDiff(x, y, out, idx)
- def testInt32NonOverlapping(self):
+ def testNonOverlapping(self):
x = [1, 2, 3, 4]
y = [5, 6]
out = x
idx = np.arange(len(x))
self._testListDiff(x, y, out, idx)
- def testInt32EmptyX(self):
+ def testEmptyX(self):
x = []
y = [1, 2]
out = []
idx = []
self._testListDiff(x, y, out, idx)
- def testInt32EmptyY(self):
+ def testEmptyY(self):
x = [1, 2, 3, 4]
y = []
out = x
idx = np.arange(len(x))
self._testListDiff(x, y, out, idx)
- def testInt32EmptyXY(self):
+ def testEmptyXY(self):
x = []
y = []
out = []