diff options
author | 2017-09-07 10:45:34 -0700 | |
---|---|---|
committer | 2017-09-07 10:49:20 -0700 | |
commit | a38d387323befd74671764f3187711332afdba5c (patch) | |
tree | 496596c5217f3e2dfc0fc4f314d590e4dbf149e4 | |
parent | 75b095a318d4db84057d0f848445c34cacf0f5ab (diff) |
Added functionality to allow SqlDataset to interpret a database column as a
`dtypes.int32`.
Previously, database columns had to be interpreted as `dtypes.string`. Support
for other TensorFlow types is forthcoming.
PiperOrigin-RevId: 167880080
4 files changed, 211 insertions, 59 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index e520bc05d6..808d25c8c7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -30,11 +30,20 @@ from tensorflow.python.platform import test class SqlDatasetTest(test.TestCase): + def _createSqlDataset(self, output_types, num_repeats=1): + dataset = dataset_ops.SqlDataset(self.driver_name, self.data_source_name, + self.query, + output_types).repeat(num_repeats) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + return init_op, get_next + def setUp(self): self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") - self.driver_name = array_ops.placeholder(dtypes.string, shape=[]) + self.driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) self.query = array_ops.placeholder(dtypes.string, shape=[]) - self.output_types = (dtypes.string, dtypes.string, dtypes.string) conn = sqlite3.connect(self.data_source_name) c = conn.cursor() @@ -42,48 +51,52 @@ class SqlDatasetTest(test.TestCase): c.execute("DROP TABLE IF EXISTS people") c.execute( "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY," - " first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100))") - c.execute( - "INSERT INTO students (first_name, last_name, motto) VALUES ('John', " - "'Doe', 'Hi!'), ('Apple', 'Orange', 'Hi again!')") + " first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100)," + " school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), " + "grade_level INTEGER, income INTEGER, favorite_number INTEGER)") + c.executemany( + "INSERT INTO students (first_name, last_name, motto, school_id, " + "favorite_nonsense_word, grade_level, income, favorite_number) " + "VALUES (?, 
?, ?, ?, ?, ?, ?, ?)", + [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647), + ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 11, -20000, + -2147483648)]) c.execute( "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, " "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))") - c.execute( - "INSERT INTO people (first_name, last_name, state) VALUES ('Benjamin'," - " 'Franklin', 'Pennsylvania'), ('John', 'Doe', 'California')") + c.executemany( + "INSERT INTO people (first_name, last_name, state) VALUES (?, ?, ?)", + [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe", + "California")]) conn.commit() conn.close() - dataset = dataset_ops.SqlDataset(self.driver_name, self.data_source_name, - self.query, self.output_types).repeat(2) - iterator = dataset.make_initializable_iterator() - self.init_op = iterator.initializer - self.get_next = iterator.get_next() - # Test that SqlDataset can read from a database table. def testReadResultSet(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string), 2) with self.test_session() as sess: for _ in range(2): # Run twice to verify statelessness of db operations. sess.run( - self.init_op, + init_op, feed_dict={ self.driver_name: "sqlite", self.query: "SELECT first_name, last_name, motto FROM students " "ORDER BY first_name DESC" }) for _ in range(2): # Dataset is repeated. See setUp. - self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(self.get_next)) - self.assertEqual((b"Apple", b"Orange", b"Hi again!"), - sess.run(self.get_next)) + self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next)) + self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) + sess.run(get_next) # Test that SqlDataset works on a join query. 
def testReadResultSetJoinQuery(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) with self.test_session() as sess: sess.run( - self.init_op, + init_op, feed_dict={ self.driver_name: "sqlite", self.query: @@ -92,32 +105,81 @@ class SqlDatasetTest(test.TestCase): "ON students.first_name = people.first_name " "AND students.last_name = people.last_name" }) - for _ in range(2): - self.assertEqual((b"John", b"California", b"Hi!"), - sess.run(self.get_next)) + self.assertEqual((b"John", b"California", b"Hi!"), sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) + sess.run(get_next) - # Test that an `OutOfRangeError` is raised on the first call to `get_next` - # if result set is empty. - def testReadEmptyResultSet(self): + # Test that SqlDataset can read a database entry with a null-terminator + # in the middle of the text and place the entry in a `string` tensor. + def testReadResultSetNullTerminator(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) with self.test_session() as sess: sess.run( - self.init_op, + init_op, feed_dict={ self.driver_name: "sqlite", + self.query: + "SELECT first_name, last_name, favorite_nonsense_word " + "FROM students ORDER BY first_name DESC" + }) + self.assertEqual((b"John", b"Doe", b"n\0nsense"), sess.run(get_next)) + self.assertEqual((b"Jane", b"Moe", b"nonsense\0"), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that SqlDataset works when used on two different queries. + # Because the output types of the dataset must be determined at graph-creation + # time, the two queries must have the same number and types of columns. 
+ def testReadResultSetReuseSqlDataset(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) + with self.test_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, last_name, motto FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next)) + self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, last_name, state FROM people " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", b"Doe", b"California"), sess.run(get_next)) + self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"), + sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that an `OutOfRangeError` is raised on the first call to + # `get_next_str_only` if result set is empty. + def testReadEmptyResultSet(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) + with self.test_session() as sess: + sess.run( + init_op, + feed_dict={ self.query: "SELECT first_name, last_name, motto FROM students " "WHERE first_name = 'Nonexistent'" }) with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) + sess.run(get_next) # Test that an error is raised when `driver_name` is invalid. 
def testReadResultSetWithInvalidDriverName(self): + init_op = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string))[0] with self.test_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run( - self.init_op, + init_op, feed_dict={ self.driver_name: "sqlfake", self.query: "SELECT first_name, last_name, motto FROM students " @@ -126,62 +188,124 @@ class SqlDatasetTest(test.TestCase): # Test that an error is raised when a column name in `query` is nonexistent def testReadResultSetWithInvalidColumnName(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) with self.test_session() as sess: sess.run( - self.init_op, + init_op, feed_dict={ - self.driver_name: "sqlite", self.query: "SELECT first_name, last_name, fake_column FROM students " "ORDER BY first_name DESC" }) with self.assertRaises(errors.UnknownError): - sess.run(self.get_next) + sess.run(get_next) # Test that an error is raised when there is a syntax error in `query`. def testReadResultSetOfQueryWithSyntaxError(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) with self.test_session() as sess: sess.run( - self.init_op, + init_op, feed_dict={ - self.driver_name: "sqlite", self.query: "SELEmispellECT first_name, last_name, motto FROM students " "ORDER BY first_name DESC" }) with self.assertRaises(errors.UnknownError): - sess.run(self.get_next) + sess.run(get_next) # Test that an error is raised when the number of columns in `query` # does not match the length of `output_types`. 
def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) with self.test_session() as sess: sess.run( - self.init_op, + init_op, feed_dict={ - self.driver_name: "sqlite", self.query: "SELECT first_name, last_name FROM students " "ORDER BY first_name DESC" }) with self.assertRaises(errors.InvalidArgumentError): - sess.run(self.get_next) + sess.run(get_next) # Test that no results are returned when `query` is an insert query rather # than a select query. In particular, the error refers to the number of # output types passed to the op not matching the number of columns in the # result set of the query (namely, 0 for an insert statement.) def testReadResultSetOfInsertQuery(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, + dtypes.string)) with self.test_session() as sess: sess.run( - self.init_op, + init_op, feed_dict={ - self.driver_name: "sqlite", self.query: "INSERT INTO students (first_name, last_name, motto) " "VALUES ('Foo', 'Bar', 'Baz'), ('Fizz', 'Buzz', 'Fizzbuzz')" }) with self.assertRaises(errors.InvalidArgumentError): - sess.run(self.get_next) + sess.run(get_next) + + def testReadResultSetInt32(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) + with self.test_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, grade_level FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", 9), sess.run(get_next)) + self.assertEqual((b"Jane", 11), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testReadResultSetInt32NegativeAndZero(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) + with self.test_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, income FROM students " + "ORDER BY first_name DESC" + }) + 
self.assertEqual((b"John", 0), sess.run(get_next)) + self.assertEqual((b"Jane", -20000), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testReadResultSetInt32MaxValues(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) + with self.test_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, favorite_number FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", 2147483647), sess.run(get_next)) + self.assertEqual((b"Jane", -2147483648), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that `SqlDataset` can read a numeric `varchar` from a SQLite database + # table and place it in an `int32` tensor. + def testReadResultSetInt32VarCharColumnAsInt(self): + init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) + with self.test_session() as sess: + sess.run( + init_op, + feed_dict={ + self.query: "SELECT first_name, school_id FROM students " + "ORDER BY first_name DESC" + }) + self.assertEqual((b"John", 123), sess.run(get_next)) + self.assertEqual((b"Jane", 1000), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) if __name__ == "__main__": diff --git a/tensorflow/core/kernels/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/sql/sqlite_query_connection.cc index 4bcf82ae28..b39b38b4b8 100644 --- a/tensorflow/core/kernels/sql/sqlite_query_connection.cc +++ b/tensorflow/core/kernels/sql/sqlite_query_connection.cc @@ -84,10 +84,9 @@ Status SqliteQueryConnection::GetNext(std::vector<Tensor>* out_tensors, for (int i = 0; i < column_count_; i++) { // TODO(b/64276939) Support other tensorflow types. Interpret columns as // the types that the client specifies. 
- Tensor tensor(cpu_allocator(), DT_STRING, {}); - string value( - reinterpret_cast<const char*>(sqlite3_column_text(stmt_, i))); - tensor.scalar<string>()() = value; + DataType dt = output_types_[i]; + Tensor tensor(cpu_allocator(), dt, {}); + FillTensorWithResultSetEntry(dt, i, &tensor); out_tensors->emplace_back(std::move(tensor)); } *end_of_sequence = false; @@ -116,6 +115,30 @@ Status SqliteQueryConnection::ExecuteQuery() { return s; } +void SqliteQueryConnection::FillTensorWithResultSetEntry( + const DataType& data_type, int column_index, Tensor* tensor) { + switch (data_type) { + case DT_STRING: { + const void* bytes = sqlite3_column_blob(stmt_, column_index); + int num_bytes = sqlite3_column_bytes(stmt_, column_index); + string value(reinterpret_cast<const char*>(bytes), num_bytes); + tensor->scalar<string>()() = value; + break; + } + case DT_INT32: { + int32 value = sqlite3_column_int(stmt_, column_index); + tensor->scalar<int32>()() = value; + break; + } + // Error preemptively thrown by SqlDatasetOp::MakeDataset in this case. + default: { + LOG(FATAL) + << "Use of unsupported TensorFlow data type by 'SqlQueryConnection': " + << DataTypeString(data_type) << "."; + } + } +} + } // namespace sql } // namespace tensorflow diff --git a/tensorflow/core/kernels/sql/sqlite_query_connection.h b/tensorflow/core/kernels/sql/sqlite_query_connection.h index f93b203a5b..917df37dc1 100644 --- a/tensorflow/core/kernels/sql/sqlite_query_connection.h +++ b/tensorflow/core/kernels/sql/sqlite_query_connection.h @@ -35,6 +35,10 @@ class SqliteQueryConnection : public QueryConnection { private: // Executes the query string `query_`. Status ExecuteQuery(); + // Fills `tensor` with the column_index_th element of the current row of + // `stmt_`. 
+ void FillTensorWithResultSetEntry(const DataType& data_type, int column_index, + Tensor* tensor); sqlite3* db_ = nullptr; sqlite3_stmt* stmt_ = nullptr; int column_count_ = 0; diff --git a/tensorflow/core/kernels/sql_dataset_ops.cc b/tensorflow/core/kernels/sql_dataset_ops.cc index d17ae53cb4..c8713f7996 100644 --- a/tensorflow/core/kernels/sql_dataset_ops.cc +++ b/tensorflow/core/kernels/sql_dataset_ops.cc @@ -34,6 +34,19 @@ class SqlDatasetOp : public DatasetOpKernel { explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + // TODO(b/64276939) Remove this check when we add support for other + // tensorflow types. + for (const DataType& dt : output_types_) { + OP_REQUIRES( + ctx, dt == DT_STRING || dt == DT_INT32, + errors::InvalidArgument( + "Each element of `output_types_` must be DT_STRING or DT_INT32")); + } + for (const PartialTensorShape& pts : output_shapes_) { + OP_REQUIRES(ctx, pts.dims() == 0, + errors::InvalidArgument( + "Each element of `output_shapes_` must be a scalar.")); + } } void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { string driver_name; @@ -54,18 +67,6 @@ class SqlDatasetOp : public DatasetOpKernel { "The database type, %s, is not supported by SqlDataset. " "The set of supported databases is: {'sqlite'}.", driver_name.c_str()))); - // TODO(b/64276939) Remove this check when we add support for other - // tensorflow types. 
- for (const DataType& dt : output_types_) { - OP_REQUIRES(ctx, dt == DataType::DT_STRING, - errors::InvalidArgument( - "Each element of `output_types_` must be DT_STRING.")); - } - for (const PartialTensorShape& pts : output_shapes_) { - OP_REQUIRES(ctx, pts.dims() == 0, - errors::InvalidArgument( - "Each element of `output_shapes_` must be a scalar.")); - } *output = new Dataset(driver_name, data_source_name, query, output_types_, output_shapes_); |