aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py28
1 files changed, 27 insertions, 1 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
index e26cef8ec5..4148addf28 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
@@ -22,6 +22,7 @@ import os
import sqlite3
+from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class SqlDatasetTest(test.TestCase):
+class SqlDatasetTestBase(test.TestCase):
def _createSqlDataset(self, output_types, num_repeats=1):
dataset = readers.SqlDataset(self.driver_name, self.data_source_name,
@@ -92,6 +93,9 @@ class SqlDatasetTest(test.TestCase):
conn.commit()
conn.close()
+
+class SqlDatasetTest(SqlDatasetTestBase):
+
# Test that SqlDataset can read from a database table.
def testReadResultSet(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
@@ -652,5 +656,27 @@ class SqlDatasetTest(test.TestCase):
sess.run(get_next)
+class SqlDatasetSerializationTest(
+ SqlDatasetTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, num_repeats):
+ data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
+ driver_name = array_ops.placeholder_with_default(
+ array_ops.constant("sqlite", dtypes.string), shape=[])
+ query = ("SELECT first_name, last_name, motto FROM students ORDER BY "
+ "first_name DESC")
+ output_types = (dtypes.string, dtypes.string, dtypes.string)
+ return readers.SqlDataset(driver_name, data_source_name, query,
+ output_types).repeat(num_repeats)
+
+ def testSQLSaveable(self):
+ num_repeats = 4
+ num_outputs = num_repeats * 2
+ self.run_core_tests(lambda: self._build_dataset(num_repeats),
+ lambda: self._build_dataset(num_repeats // 2),
+ num_outputs)
+
+
if __name__ == "__main__":
test.main()