diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/decode_csv_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/decode_csv_op_test.py | 55 |
1 files changed, 41 insertions, 14 deletions
diff --git a/tensorflow/python/kernel_tests/decode_csv_op_test.py b/tensorflow/python/kernel_tests/decode_csv_op_test.py index 4f49d72676..e9307a6b2f 100644 --- a/tensorflow/python/kernel_tests/decode_csv_op_test.py +++ b/tensorflow/python/kernel_tests/decode_csv_op_test.py @@ -20,28 +20,30 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import context +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test +@test_util.run_all_in_graph_and_eager_modes class DecodeCSVOpTest(test.TestCase): def _test(self, args, expected_out=None, expected_err_re=None): - with self.test_session() as sess: + if expected_err_re is None: decode = parsing_ops.decode_csv(**args) - - if expected_err_re is None: - out = sess.run(decode) - - for i, field in enumerate(out): - if field.dtype == np.float32 or field.dtype == np.float64: - self.assertAllClose(field, expected_out[i]) - else: - self.assertAllEqual(field, expected_out[i]) - - else: - with self.assertRaisesOpError(expected_err_re): - sess.run(decode) + out = self.evaluate(decode) + + for i, field in enumerate(out): + if field.dtype == np.float32 or field.dtype == np.float64: + self.assertAllClose(field, expected_out[i]) + else: + self.assertAllEqual(field, expected_out[i]) + else: + with self.assertRaisesOpError(expected_err_re): + decode = parsing_ops.decode_csv(**args) + self.evaluate(decode) def testSimple(self): args = { @@ -53,6 +55,31 @@ class DecodeCSVOpTest(test.TestCase): self._test(args, expected_out) + def testSimpleWithScalarDefaults(self): + args = { + "records": ["1,4", "2,5", "3,6"], + "record_defaults": [1, 2], + } + + expected_out = [[1, 2, 3], [4, 5, 6]] + + self._test(args, expected_out) + + def testSimpleWith2DDefaults(self): + args = { + "records": ["1", "2", "3"], + "record_defaults": [[[0]]], + } + + if context.executing_eagerly(): + err_spec = errors.InvalidArgumentError, ( + "Each record default should be at " + "most rank 1.") + else: + err_spec = ValueError, "Shape must be at most rank 1 but is rank 2" + with self.assertRaisesWithPredicateMatch(*err_spec): + self._test(args) + def testSimpleNoQuoteDelimiter(self): args = { "records": ["1", "2", '"3"'], |