aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-07 10:45:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-07 10:49:20 -0700
commita38d387323befd74671764f3187711332afdba5c (patch)
tree496596c5217f3e2dfc0fc4f314d590e4dbf149e4
parent75b095a318d4db84057d0f848445c34cacf0f5ab (diff)
Added functionality to allow SqlDataset to interpret a database column as a
`dtypes.int32`. Previously, database columns had to be interpreted as `dtypes.string`. Support for other TensorFlow types is forthcoming. PiperOrigin-RevId: 167880080
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py210
-rw-r--r--tensorflow/core/kernels/sql/sqlite_query_connection.cc31
-rw-r--r--tensorflow/core/kernels/sql/sqlite_query_connection.h4
-rw-r--r--tensorflow/core/kernels/sql_dataset_ops.cc25
4 files changed, 211 insertions, 59 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
index e520bc05d6..808d25c8c7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
@@ -30,11 +30,20 @@ from tensorflow.python.platform import test
class SqlDatasetTest(test.TestCase):
+ def _createSqlDataset(self, output_types, num_repeats=1):
+ dataset = dataset_ops.SqlDataset(self.driver_name, self.data_source_name,
+ self.query,
+ output_types).repeat(num_repeats)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ return init_op, get_next
+
def setUp(self):
self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
- self.driver_name = array_ops.placeholder(dtypes.string, shape=[])
+ self.driver_name = array_ops.placeholder_with_default(
+ array_ops.constant("sqlite", dtypes.string), shape=[])
self.query = array_ops.placeholder(dtypes.string, shape=[])
- self.output_types = (dtypes.string, dtypes.string, dtypes.string)
conn = sqlite3.connect(self.data_source_name)
c = conn.cursor()
@@ -42,48 +51,52 @@ class SqlDatasetTest(test.TestCase):
c.execute("DROP TABLE IF EXISTS people")
c.execute(
"CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY,"
- " first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100))")
- c.execute(
- "INSERT INTO students (first_name, last_name, motto) VALUES ('John', "
- "'Doe', 'Hi!'), ('Apple', 'Orange', 'Hi again!')")
+ " first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100),"
+ " school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
+ "grade_level INTEGER, income INTEGER, favorite_number INTEGER)")
+ c.executemany(
+ "INSERT INTO students (first_name, last_name, motto, school_id, "
+ "favorite_nonsense_word, grade_level, income, favorite_number) "
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
+ [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647),
+ ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 11, -20000,
+ -2147483648)])
c.execute(
"CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
"first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
- c.execute(
- "INSERT INTO people (first_name, last_name, state) VALUES ('Benjamin',"
- " 'Franklin', 'Pennsylvania'), ('John', 'Doe', 'California')")
+ c.executemany(
+ "INSERT INTO people (first_name, last_name, state) VALUES (?, ?, ?)",
+ [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe",
+ "California")])
conn.commit()
conn.close()
- dataset = dataset_ops.SqlDataset(self.driver_name, self.data_source_name,
- self.query, self.output_types).repeat(2)
- iterator = dataset.make_initializable_iterator()
- self.init_op = iterator.initializer
- self.get_next = iterator.get_next()
-
# Test that SqlDataset can read from a database table.
def testReadResultSet(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string), 2)
with self.test_session() as sess:
for _ in range(2): # Run twice to verify statelessness of db operations.
sess.run(
- self.init_op,
+ init_op,
feed_dict={
self.driver_name: "sqlite",
self.query: "SELECT first_name, last_name, motto FROM students "
"ORDER BY first_name DESC"
})
for _ in range(2): # Dataset is repeated. See setUp.
- self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(self.get_next))
- self.assertEqual((b"Apple", b"Orange", b"Hi again!"),
- sess.run(self.get_next))
+ self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
+ self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.get_next)
+ sess.run(get_next)
# Test that SqlDataset works on a join query.
def testReadResultSetJoinQuery(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
with self.test_session() as sess:
sess.run(
- self.init_op,
+ init_op,
feed_dict={
self.driver_name: "sqlite",
self.query:
@@ -92,32 +105,81 @@ class SqlDatasetTest(test.TestCase):
"ON students.first_name = people.first_name "
"AND students.last_name = people.last_name"
})
- for _ in range(2):
- self.assertEqual((b"John", b"California", b"Hi!"),
- sess.run(self.get_next))
+ self.assertEqual((b"John", b"California", b"Hi!"), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.get_next)
+ sess.run(get_next)
- # Test that an `OutOfRangeError` is raised on the first call to `get_next`
- # if result set is empty.
- def testReadEmptyResultSet(self):
+ # Test that SqlDataset can read a database entry with a null-terminator
+ # in the middle of the text and place the entry in a `string` tensor.
+ def testReadResultSetNullTerminator(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
with self.test_session() as sess:
sess.run(
- self.init_op,
+ init_op,
feed_dict={
self.driver_name: "sqlite",
+ self.query:
+ "SELECT first_name, last_name, favorite_nonsense_word "
+ "FROM students ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", b"Doe", b"n\0nsense"), sess.run(get_next))
+ self.assertEqual((b"Jane", b"Moe", b"nonsense\0"), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that SqlDataset works when used on two different queries.
+ # Because the output types of the dataset must be determined at graph-creation
+ # time, the two queries must have the same number and types of columns.
+ def testReadResultSetReuseSqlDataset(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.test_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, last_name, motto FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
+ self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, last_name, state FROM people "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", b"Doe", b"California"), sess.run(get_next))
+ self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"),
+ sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+  # Test that an `OutOfRangeError` is raised on the first call to
+  # `get_next` if result set is empty.
+ def testReadEmptyResultSet(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.test_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
self.query: "SELECT first_name, last_name, motto FROM students "
"WHERE first_name = 'Nonexistent'"
})
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.get_next)
+ sess.run(get_next)
# Test that an error is raised when `driver_name` is invalid.
def testReadResultSetWithInvalidDriverName(self):
+ init_op = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))[0]
with self.test_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
- self.init_op,
+ init_op,
feed_dict={
self.driver_name: "sqlfake",
self.query: "SELECT first_name, last_name, motto FROM students "
@@ -126,62 +188,124 @@ class SqlDatasetTest(test.TestCase):
  # Test that an error is raised when a column name in `query` is nonexistent.
def testReadResultSetWithInvalidColumnName(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
with self.test_session() as sess:
sess.run(
- self.init_op,
+ init_op,
feed_dict={
- self.driver_name: "sqlite",
self.query:
"SELECT first_name, last_name, fake_column FROM students "
"ORDER BY first_name DESC"
})
with self.assertRaises(errors.UnknownError):
- sess.run(self.get_next)
+ sess.run(get_next)
# Test that an error is raised when there is a syntax error in `query`.
def testReadResultSetOfQueryWithSyntaxError(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
with self.test_session() as sess:
sess.run(
- self.init_op,
+ init_op,
feed_dict={
- self.driver_name: "sqlite",
self.query:
"SELEmispellECT first_name, last_name, motto FROM students "
"ORDER BY first_name DESC"
})
with self.assertRaises(errors.UnknownError):
- sess.run(self.get_next)
+ sess.run(get_next)
# Test that an error is raised when the number of columns in `query`
# does not match the length of `output_types`.
def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
with self.test_session() as sess:
sess.run(
- self.init_op,
+ init_op,
feed_dict={
- self.driver_name: "sqlite",
self.query: "SELECT first_name, last_name FROM students "
"ORDER BY first_name DESC"
})
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(self.get_next)
+ sess.run(get_next)
# Test that no results are returned when `query` is an insert query rather
# than a select query. In particular, the error refers to the number of
# output types passed to the op not matching the number of columns in the
# result set of the query (namely, 0 for an insert statement.)
def testReadResultSetOfInsertQuery(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
with self.test_session() as sess:
sess.run(
- self.init_op,
+ init_op,
feed_dict={
- self.driver_name: "sqlite",
self.query:
"INSERT INTO students (first_name, last_name, motto) "
"VALUES ('Foo', 'Bar', 'Baz'), ('Fizz', 'Buzz', 'Fizzbuzz')"
})
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(self.get_next)
+ sess.run(get_next)
+
+ def testReadResultSetInt32(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
+ with self.test_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, grade_level FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 9), sess.run(get_next))
+ self.assertEqual((b"Jane", 11), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testReadResultSetInt32NegativeAndZero(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
+ with self.test_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, income FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 0), sess.run(get_next))
+ self.assertEqual((b"Jane", -20000), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testReadResultSetInt32MaxValues(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
+ with self.test_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, favorite_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 2147483647), sess.run(get_next))
+ self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a numeric `varchar` from a SQLite database
+ # table and place it in an `int32` tensor.
+ def testReadResultSetInt32VarCharColumnAsInt(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
+ with self.test_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, school_id FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 123), sess.run(get_next))
+ self.assertEqual((b"Jane", 1000), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
if __name__ == "__main__":
diff --git a/tensorflow/core/kernels/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/sql/sqlite_query_connection.cc
index 4bcf82ae28..b39b38b4b8 100644
--- a/tensorflow/core/kernels/sql/sqlite_query_connection.cc
+++ b/tensorflow/core/kernels/sql/sqlite_query_connection.cc
@@ -84,10 +84,9 @@ Status SqliteQueryConnection::GetNext(std::vector<Tensor>* out_tensors,
for (int i = 0; i < column_count_; i++) {
// TODO(b/64276939) Support other tensorflow types. Interpret columns as
// the types that the client specifies.
- Tensor tensor(cpu_allocator(), DT_STRING, {});
- string value(
- reinterpret_cast<const char*>(sqlite3_column_text(stmt_, i)));
- tensor.scalar<string>()() = value;
+ DataType dt = output_types_[i];
+ Tensor tensor(cpu_allocator(), dt, {});
+ FillTensorWithResultSetEntry(dt, i, &tensor);
out_tensors->emplace_back(std::move(tensor));
}
*end_of_sequence = false;
@@ -116,6 +115,30 @@ Status SqliteQueryConnection::ExecuteQuery() {
return s;
}
+void SqliteQueryConnection::FillTensorWithResultSetEntry(
+ const DataType& data_type, int column_index, Tensor* tensor) {
+ switch (data_type) {
+ case DT_STRING: {
+ const void* bytes = sqlite3_column_blob(stmt_, column_index);
+ int num_bytes = sqlite3_column_bytes(stmt_, column_index);
+ string value(reinterpret_cast<const char*>(bytes), num_bytes);
+ tensor->scalar<string>()() = value;
+ break;
+ }
+ case DT_INT32: {
+ int32 value = sqlite3_column_int(stmt_, column_index);
+ tensor->scalar<int32>()() = value;
+ break;
+ }
+ // Error preemptively thrown by SqlDatasetOp::MakeDataset in this case.
+ default: {
+ LOG(FATAL)
+ << "Use of unsupported TensorFlow data type by 'SqlQueryConnection': "
+ << DataTypeString(data_type) << ".";
+ }
+ }
+}
+
} // namespace sql
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/sql/sqlite_query_connection.h b/tensorflow/core/kernels/sql/sqlite_query_connection.h
index f93b203a5b..917df37dc1 100644
--- a/tensorflow/core/kernels/sql/sqlite_query_connection.h
+++ b/tensorflow/core/kernels/sql/sqlite_query_connection.h
@@ -35,6 +35,10 @@ class SqliteQueryConnection : public QueryConnection {
private:
// Executes the query string `query_`.
Status ExecuteQuery();
+ // Fills `tensor` with the column_index_th element of the current row of
+ // `stmt_`.
+ void FillTensorWithResultSetEntry(const DataType& data_type, int column_index,
+ Tensor* tensor);
sqlite3* db_ = nullptr;
sqlite3_stmt* stmt_ = nullptr;
int column_count_ = 0;
diff --git a/tensorflow/core/kernels/sql_dataset_ops.cc b/tensorflow/core/kernels/sql_dataset_ops.cc
index d17ae53cb4..c8713f7996 100644
--- a/tensorflow/core/kernels/sql_dataset_ops.cc
+++ b/tensorflow/core/kernels/sql_dataset_ops.cc
@@ -34,6 +34,19 @@ class SqlDatasetOp : public DatasetOpKernel {
explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ // TODO(b/64276939) Remove this check when we add support for other
+ // tensorflow types.
+ for (const DataType& dt : output_types_) {
+ OP_REQUIRES(
+ ctx, dt == DT_STRING || dt == DT_INT32,
+ errors::InvalidArgument(
+ "Each element of `output_types_` must be DT_STRING or DT_INT32"));
+ }
+ for (const PartialTensorShape& pts : output_shapes_) {
+ OP_REQUIRES(ctx, pts.dims() == 0,
+ errors::InvalidArgument(
+ "Each element of `output_shapes_` must be a scalar."));
+ }
}
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
string driver_name;
@@ -54,18 +67,6 @@ class SqlDatasetOp : public DatasetOpKernel {
"The database type, %s, is not supported by SqlDataset. "
"The set of supported databases is: {'sqlite'}.",
driver_name.c_str())));
- // TODO(b/64276939) Remove this check when we add support for other
- // tensorflow types.
- for (const DataType& dt : output_types_) {
- OP_REQUIRES(ctx, dt == DataType::DT_STRING,
- errors::InvalidArgument(
- "Each element of `output_types_` must be DT_STRING."));
- }
- for (const PartialTensorShape& pts : output_shapes_) {
- OP_REQUIRES(ctx, pts.dims() == 0,
- errors::InvalidArgument(
- "Each element of `output_shapes_` must be a scalar."));
- }
*output = new Dataset(driver_name, data_source_name, query, output_types_,
output_shapes_);