aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/decode_csv_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/decode_csv_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/decode_csv_op_test.py55
1 files changed, 41 insertions, 14 deletions
diff --git a/tensorflow/python/kernel_tests/decode_csv_op_test.py b/tensorflow/python/kernel_tests/decode_csv_op_test.py
index 4f49d72676..e9307a6b2f 100644
--- a/tensorflow/python/kernel_tests/decode_csv_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_csv_op_test.py
@@ -20,28 +20,30 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
+@test_util.run_all_in_graph_and_eager_modes
class DecodeCSVOpTest(test.TestCase):
def _test(self, args, expected_out=None, expected_err_re=None):
- with self.test_session() as sess:
+ if expected_err_re is None:
decode = parsing_ops.decode_csv(**args)
-
- if expected_err_re is None:
- out = sess.run(decode)
-
- for i, field in enumerate(out):
- if field.dtype == np.float32 or field.dtype == np.float64:
- self.assertAllClose(field, expected_out[i])
- else:
- self.assertAllEqual(field, expected_out[i])
-
- else:
- with self.assertRaisesOpError(expected_err_re):
- sess.run(decode)
+ out = self.evaluate(decode)
+
+ for i, field in enumerate(out):
+ if field.dtype == np.float32 or field.dtype == np.float64:
+ self.assertAllClose(field, expected_out[i])
+ else:
+ self.assertAllEqual(field, expected_out[i])
+ else:
+ with self.assertRaisesOpError(expected_err_re):
+ decode = parsing_ops.decode_csv(**args)
+ self.evaluate(decode)
def testSimple(self):
args = {
@@ -53,6 +55,31 @@ class DecodeCSVOpTest(test.TestCase):
self._test(args, expected_out)
+ def testSimpleWithScalarDefaults(self):
+ args = {
+ "records": ["1,4", "2,5", "3,6"],
+ "record_defaults": [1, 2],
+ }
+
+ expected_out = [[1, 2, 3], [4, 5, 6]]
+
+ self._test(args, expected_out)
+
+ def testSimpleWith2DDefaults(self):
+ args = {
+ "records": ["1", "2", "3"],
+ "record_defaults": [[[0]]],
+ }
+
+ if context.executing_eagerly():
+ err_spec = errors.InvalidArgumentError, (
+ "Each record default should be at "
+ "most rank 1.")
+ else:
+ err_spec = ValueError, "Shape must be at most rank 1 but is rank 2"
+ with self.assertRaisesWithPredicateMatch(*err_spec):
+ self._test(args)
+
def testSimpleNoQuoteDelimiter(self):
args = {
"records": ["1", "2", '"3"'],