diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/decode_csv_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/decode_csv_op_test.py | 148 |
1 files changed, 148 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/decode_csv_op_test.py b/tensorflow/python/kernel_tests/decode_csv_op_test.py new file mode 100644 index 0000000000..ae0917f8c4 --- /dev/null +++ b/tensorflow/python/kernel_tests/decode_csv_op_test.py @@ -0,0 +1,148 @@ +"""Tests for DecodeCSV op from parsing_ops.""" + +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + + +class DecodeCSVOpTest(tf.test.TestCase): + + def _test(self, args, expected_out=None, expected_err_re=None): + with self.test_session() as sess: + decode = tf.decode_csv(**args) + + if expected_err_re is None: + out = sess.run(decode) + + for i, field in enumerate(out): + if field.dtype == np.float32: + self.assertAllClose(field, expected_out[i]) + else: + self.assertAllEqual(field, expected_out[i]) + + else: + with self.assertRaisesOpError(expected_err_re): + sess.run(decode) + + def testSimple(self): + args = {"records": ["1", "2", '"3"'], "record_defaults": [[1]],} + + expected_out = [[1, 2, 3]] + + self._test(args, expected_out) + + def testScalar(self): + args = {"records": '1,""', "record_defaults": [[3], [4]]} + + expected_out = [1, 4] + + self._test(args, expected_out) + + def test2D(self): + args = {"records": [["1", "2"], ['""', "4"]], "record_defaults": [[5]]} + expected_out = [[[1, 2], [5, 4]]] + + self._test(args, expected_out) + + def testInt64(self): + args = { + "records": ["1", "2", '"2147483648"'], + "record_defaults": [np.array([], + dtype=np.int64)], + } + + expected_out = [[1, 2, 2147483648]] + + self._test(args, expected_out) + + def testComplexString(self): + args = { + "records": ['"1.0"', '"ab , c"', '"a\nbc"', '"ab""c"', " abc "], + "record_defaults": [["1"]] + } + + expected_out = [["1.0", "ab , c", "a\nbc", 'ab"c', " abc "]] + + self._test(args, expected_out) + + def testMultiRecords(self): + args = { + "records": ["1.0,4,aa", "0.2,5,bb", "3,6,cc"], + "record_defaults": [[1.0], [1], ["aa"]] + } + + expected_out = [[1.0, 0.2, 3], [4, 5, 6], ["aa", "bb", "cc"]] + + self._test(args, expected_out) + + def testWithDefaults(self): + args = { + "records": [",1,", "0.2,3,bcd", "3.0,,"], + "record_defaults": [[1.0], [0], ["a"]] + } + + expected_out = [[1.0, 0.2, 3.0], [1, 3, 0], ["a", "bcd", "a"]] + + self._test(args, expected_out) + + def testWithTabDelim(self): + args = { + "records": ["1\t1", "0.2\t3", "3.0\t"], + "record_defaults": [[1.0], [0]], + "field_delim": "\t" + } + + expected_out = [[1.0, 0.2, 3.0], [1, 3, 0]] + + self._test(args, expected_out) + + def testWithoutDefaultsError(self): + args = { + "records": [",1", "0.2,3", "3.0,"], + "record_defaults": [[1.0], np.array([], + dtype=np.int32)] + } + + self._test(args, + expected_err_re="Field 1 is required but missing in record 2!") + + def testWrongFieldIntError(self): + args = { + "records": [",1", "0.2,234a", "3.0,2"], + "record_defaults": [[1.0], np.array([], + dtype=np.int32)] + } + + self._test(args, + expected_err_re="Field 1 in record 1 is not a valid int32: 234a") + + def testOutOfRangeError(self): + args = { + "records": ["1", "9999999999999999999999999", "3"], + "record_defaults": [[1]] + } + + self._test(args, + expected_err_re="Field 0 in record 1 is not a valid int32: ") + + def testWrongFieldFloatError(self): + args = { + "records": [",1", "0.2,2", "3.0adf,3"], + "record_defaults": [[1.0], np.array([], + dtype=np.int32)] + } + + self._test(args, + expected_err_re="Field 0 in record 2 is not a valid float: ") + + def testWrongFieldStringError(self): + args = {"records": ['"1,a,"', "0.22", 'a"bc'], "record_defaults": [["a"]]} + + self._test( + args, + expected_err_re="Unquoted fields cannot have quotes/CRLFs inside") + + +if __name__ == "__main__": + tf.test.main() |