aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/decode_csv_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/decode_csv_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/decode_csv_op_test.py148
1 files changed, 148 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/decode_csv_op_test.py b/tensorflow/python/kernel_tests/decode_csv_op_test.py
new file mode 100644
index 0000000000..ae0917f8c4
--- /dev/null
+++ b/tensorflow/python/kernel_tests/decode_csv_op_test.py
@@ -0,0 +1,148 @@
+"""Tests for DecodeCSV op from parsing_ops."""
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class DecodeCSVOpTest(tf.test.TestCase):
+
+ def _test(self, args, expected_out=None, expected_err_re=None):
+ with self.test_session() as sess:
+ decode = tf.decode_csv(**args)
+
+ if expected_err_re is None:
+ out = sess.run(decode)
+
+ for i, field in enumerate(out):
+ if field.dtype == np.float32:
+ self.assertAllClose(field, expected_out[i])
+ else:
+ self.assertAllEqual(field, expected_out[i])
+
+ else:
+ with self.assertRaisesOpError(expected_err_re):
+ sess.run(decode)
+
+ def testSimple(self):
+ args = {"records": ["1", "2", '"3"'], "record_defaults": [[1]],}
+
+ expected_out = [[1, 2, 3]]
+
+ self._test(args, expected_out)
+
+ def testScalar(self):
+ args = {"records": '1,""', "record_defaults": [[3], [4]]}
+
+ expected_out = [1, 4]
+
+ self._test(args, expected_out)
+
+ def test2D(self):
+ args = {"records": [["1", "2"], ['""', "4"]], "record_defaults": [[5]]}
+ expected_out = [[[1, 2], [5, 4]]]
+
+ self._test(args, expected_out)
+
+ def testInt64(self):
+ args = {
+ "records": ["1", "2", '"2147483648"'],
+ "record_defaults": [np.array([],
+ dtype=np.int64)],
+ }
+
+ expected_out = [[1, 2, 2147483648]]
+
+ self._test(args, expected_out)
+
+ def testComplexString(self):
+ args = {
+ "records": ['"1.0"', '"ab , c"', '"a\nbc"', '"ab""c"', " abc "],
+ "record_defaults": [["1"]]
+ }
+
+ expected_out = [["1.0", "ab , c", "a\nbc", 'ab"c', " abc "]]
+
+ self._test(args, expected_out)
+
+ def testMultiRecords(self):
+ args = {
+ "records": ["1.0,4,aa", "0.2,5,bb", "3,6,cc"],
+ "record_defaults": [[1.0], [1], ["aa"]]
+ }
+
+ expected_out = [[1.0, 0.2, 3], [4, 5, 6], ["aa", "bb", "cc"]]
+
+ self._test(args, expected_out)
+
+ def testWithDefaults(self):
+ args = {
+ "records": [",1,", "0.2,3,bcd", "3.0,,"],
+ "record_defaults": [[1.0], [0], ["a"]]
+ }
+
+ expected_out = [[1.0, 0.2, 3.0], [1, 3, 0], ["a", "bcd", "a"]]
+
+ self._test(args, expected_out)
+
+ def testWithTabDelim(self):
+ args = {
+ "records": ["1\t1", "0.2\t3", "3.0\t"],
+ "record_defaults": [[1.0], [0]],
+ "field_delim": "\t"
+ }
+
+ expected_out = [[1.0, 0.2, 3.0], [1, 3, 0]]
+
+ self._test(args, expected_out)
+
+ def testWithoutDefaultsError(self):
+ args = {
+ "records": [",1", "0.2,3", "3.0,"],
+ "record_defaults": [[1.0], np.array([],
+ dtype=np.int32)]
+ }
+
+ self._test(args,
+ expected_err_re="Field 1 is required but missing in record 2!")
+
+ def testWrongFieldIntError(self):
+ args = {
+ "records": [",1", "0.2,234a", "3.0,2"],
+ "record_defaults": [[1.0], np.array([],
+ dtype=np.int32)]
+ }
+
+ self._test(args,
+ expected_err_re="Field 1 in record 1 is not a valid int32: 234a")
+
+ def testOutOfRangeError(self):
+ args = {
+ "records": ["1", "9999999999999999999999999", "3"],
+ "record_defaults": [[1]]
+ }
+
+ self._test(args,
+ expected_err_re="Field 0 in record 1 is not a valid int32: ")
+
+ def testWrongFieldFloatError(self):
+ args = {
+ "records": [",1", "0.2,2", "3.0adf,3"],
+ "record_defaults": [[1.0], np.array([],
+ dtype=np.int32)]
+ }
+
+ self._test(args,
+ expected_err_re="Field 0 in record 2 is not a valid float: ")
+
+ def testWrongFieldStringError(self):
+ args = {"records": ['"1,a,"', "0.22", 'a"bc'], "record_defaults": [["a"]]}
+
+ self._test(
+ args,
+ expected_err_re="Unquoted fields cannot have quotes/CRLFs inside")
+
+
+if __name__ == "__main__":
+ tf.test.main()