diff options
-rw-r--r-- | tensorflow/contrib/cmake/CMakeLists.txt | 3 | ||||
-rw-r--r-- | tensorflow/contrib/cmake/external/sqlite.cmake | 1 | ||||
-rw-r--r-- | tensorflow/contrib/cmake/tf_core_framework.cmake | 5 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/summary/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/tensorboard/db/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/tensorboard/db/summary_db_writer.cc | 20 | ||||
-rw-r--r-- | tensorflow/core/lib/db/BUILD | 15 | ||||
-rw-r--r-- | tensorflow/core/lib/db/snapfn.cc | 253 | ||||
-rw-r--r-- | tensorflow/core/lib/db/sqlite.cc | 3 | ||||
-rw-r--r-- | tensorflow/core/lib/db/sqlite_test.cc | 17 | ||||
-rw-r--r-- | tensorflow/workspace.bzl | 2 | ||||
-rw-r--r-- | third_party/sqlite.BUILD | 61 |
13 files changed, 358 insertions, 25 deletions
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 8d023cc81d..2c4b7486d5 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -106,6 +106,9 @@ else() set(CMAKE_POSITION_INDEPENDENT_CODE OFF) endif() +# TODO(jart): We should make this only apply to snapfn.cc +add_definitions(-DSQLITE_OMIT_LOAD_EXTENSION) + if (tensorflow_DISABLE_EIGEN_FORCEINLINE) add_definitions(-DEIGEN_STRONG_INLINE=inline) endif() diff --git a/tensorflow/contrib/cmake/external/sqlite.cmake b/tensorflow/contrib/cmake/external/sqlite.cmake index 785039a469..14d8148e6e 100644 --- a/tensorflow/contrib/cmake/external/sqlite.cmake +++ b/tensorflow/contrib/cmake/external/sqlite.cmake @@ -28,6 +28,7 @@ endif() set(sqlite_HEADERS "${sqlite_BUILD}/sqlite3.h" + "${sqlite_BUILD}/sqlite3ext.h" ) if (WIN32) diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 5ec1a8d04f..08f015445a 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -319,6 +319,11 @@ file(GLOB_RECURSE tf_core_framework_exclude_srcs "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*test*.cc" ) +# TODO(jart): Why doesn't this work? +# set_source_files_properties( +# ${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/snapfn.cc +# PROPERTIES COMPILE_FLAGS -DSQLITE_OMIT_LOAD_EXTENSION) + list(REMOVE_ITEM tf_core_framework_srcs ${tf_core_framework_exclude_srcs}) add_library(tf_core_framework OBJECT diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index ea43357e48..9051677929 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -467,6 +467,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "@org_sqlite//:python", ], ) diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index 5ee5f1ae76..b58c83fdaf 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -112,5 +112,6 @@ py_library( "//tensorflow/core:protos_all_py", "//tensorflow/python:lib", "//tensorflow/python:platform", + "@org_sqlite//:python", ], ) diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD index 9d3d60c24d..3a3402c59b 100644 --- a/tensorflow/contrib/tensorboard/db/BUILD +++ b/tensorflow/contrib/tensorboard/db/BUILD @@ -19,6 +19,7 @@ cc_library( tf_cc_test( name = "schema_test", + size = "small", srcs = ["schema_test.cc"], deps = [ ":schema", diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index 04b9c8e457..8929605817 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/fingerprint.h" -#include "tensorflow/core/platform/snappy.h" #include "tensorflow/core/util/event.pb.h" namespace tensorflow { @@ -56,21 +55,11 @@ Status Serialize(const protobuf::MessageLite& proto, string* output) { return Status::OK(); } -Status Compress(const string& data, string* output) { - output->clear(); - if (!port::Snappy_Compress(data.data(), data.size(), output)) { - return errors::FailedPrecondition("TensorBase needs Snappy"); - } - return Status::OK(); -} - Status BindProto(SqliteStatement* stmt, int parameter, const protobuf::MessageLite& proto) { string serialized; TF_RETURN_IF_ERROR(Serialize(proto, &serialized)); - string compressed; - TF_RETURN_IF_ERROR(Compress(serialized, &compressed)); - stmt->BindBlob(parameter, compressed); + stmt->BindBlob(parameter, serialized); return Status::OK(); } @@ -78,7 +67,6 @@ Status BindTensor(SqliteStatement* stmt, int parameter, const Tensor& t) { // TODO(@jart): Make portable between little and big endian systems. // TODO(@jart): Use TensorChunks with minimal copying for big tensors. // TODO(@jart): Add field to indicate encoding. - // TODO(@jart): Allow crunch tool to re-compress with zlib instead. TensorProto p; t.AsProtoTensorContent(&p); return BindProto(stmt, parameter, p); @@ -250,7 +238,7 @@ class GraphSaver { Status SaveNodes() { auto insert = db_->Prepare(R"sql( INSERT INTO Nodes (graph_id, node_id, node_name, op, device, node_def) - VALUES (?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, snap(?)) )sql"); for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { NodeDef* node = graph_->mutable_node(node_id); @@ -276,7 +264,7 @@ class GraphSaver { Status SaveGraph() { auto insert = db_->Prepare(R"sql( INSERT INTO Graphs (graph_id, inserted_time, graph_def) - VALUES (?, ?, ?) + VALUES (?, ?, snap(?)) )sql"); insert.BindInt(1, graph_id_); insert.BindDouble(2, GetWallTime(env_)); @@ -305,7 +293,7 @@ class RunWriter { user_name_{user_name}, insert_tensor_{db_->Prepare(R"sql( INSERT OR REPLACE INTO Tensors (tag_id, step, computed_time, tensor) - VALUES (?, ?, ?, ?) + VALUES (?, ?, ?, snap(?)) )sql")} {} ~RunWriter() { diff --git a/tensorflow/core/lib/db/BUILD b/tensorflow/core/lib/db/BUILD index 41b7af1b69..d98c7785d2 100644 --- a/tensorflow/core/lib/db/BUILD +++ b/tensorflow/core/lib/db/BUILD @@ -12,14 +12,27 @@ cc_library( srcs = ["sqlite.cc"], hdrs = ["sqlite.h"], deps = [ + ":snapfn", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", - "@sqlite_archive//:sqlite", + "@org_sqlite", + ], +) + +cc_library( + name = "snapfn", + srcs = ["snapfn.cc"], + copts = ["-DSQLITE_OMIT_LOAD_EXTENSION"], + linkstatic = 1, + deps = [ + "@org_sqlite", + "@snappy", ], ) tf_cc_test( name = "sqlite_test", + size = "small", srcs = ["sqlite_test.cc"], deps = [ ":sqlite", diff --git a/tensorflow/core/lib/db/snapfn.cc b/tensorflow/core/lib/db/snapfn.cc new file mode 100644 index 0000000000..4a659f41ed --- /dev/null +++ b/tensorflow/core/lib/db/snapfn.cc @@ -0,0 +1,253 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/// \brief SQLite extension for Snappy compression +/// +/// Snappy a compression library that trades ratio for speed, almost going a +/// tenth as fast as memcpy(). +/// +/// FUNCTIONS +/// +/// - snap(value: BLOB|TEXT) -> BLOB +/// - snap(value: NULL|INT|REAL) -> value +/// +/// Applies Snappy compression. If value is TEXT or BLOB, then it is +/// compressed and a BLOB is returned with a byte prepended to indicate the +/// original type. Other types are returned as-is. +/// +/// - unsnap(value: BLOB) -> TEXT|BLOB +/// - unsnap(value: TEXT) -> SQLITE_MISMATCH +/// - unsnap(value: NULL|INT|REAL) -> value +/// +/// Decompresses value created by snap(). If value is empty, then an empty +/// blob is returned. Otherwise the original type is restored from the first +/// byte and the remaining ones are decompressed. TEXT is not allowed as an +/// input type. Remaining types are returned as-is. +/// +/// PERFORMANCE CONSIDERATIONS +/// +/// These functions are deterministic. This means SQLite ≥3.8.3 will factor +/// them out of inner loops when constant arguments are provided. In SQLite +/// ≥3.15.0 they can be used in the WHERE clause of partial indexes. Currently +/// there is no support for common sub-expression elimination. +/// +/// SQLite environments that aren't universally UTF8 will work, but should +/// encounter superfluous charset transcodings; as this implementation encodes +/// only UTF8 TEXT for the sake of simplicity. Contributions are welcome that +/// register multiple sister functions for the various charsets, which use the +/// higher order bits of the type byte to indicate encoding. +/// +/// SUPPORT MATRIX +/// +/// - 3.20.0 (2016-05-18) What FOSS TensorFlow uses +/// - 3.13.0 (2016-05-18) What Google uses c. 2017-12 +/// - 3.8.2 (2013-12-06) Used by Ubuntu 14.04 +/// +/// MANUAL COMPILATION +/// +/// $ sudo apt-get install libsqlite3-dev libsnappy-dev +/// $ c++ -shared --std=c++11 -o libsnapfn.so -fPIC snapfn.cc -lsnappy +/// +/// $ sqlite3 +/// sqlite> .load libsnapfn.so +/// sqlite> select hex(snap('aaaaaaaaaaaaaaaaa')); +/// 031100613E0100 +/// sqlite> select unsnap(x'031100613E0100'); +/// aaaaaaaaaaaaaaaaa +/// +/// $ python +/// >>> import sqlite3 +/// >>> db = sqlite3.connect(':memory:') +/// >>> db.enable_load_extension(True) +/// >>> db.execute('select load_extension("libsnapfn.so")') +/// >>> db.enable_load_extension(False) +/// >>> db.execute('select hex(snap("aaaaaaaaaaaaaaaaa"))').fetchone()[0] +/// u'031100613E0100' + +#include "sqlite3ext.h" +#include "snappy.h" + +SQLITE_EXTENSION_INIT1 + +static void snap(sqlite3_context* ctx, int /*argc*/, sqlite3_value** argv) { + const char* data; + int type = sqlite3_value_type(argv[0]); + switch (type) { + case SQLITE_NULL: + return; + case SQLITE_INTEGER: + sqlite3_result_int64(ctx, sqlite3_value_int64(argv[0])); + return; + case SQLITE_FLOAT: + sqlite3_result_double(ctx, sqlite3_value_double(argv[0])); + return; + case SQLITE_BLOB: + data = reinterpret_cast<const char*>(sqlite3_value_blob(argv[0])); + break; + case SQLITE_TEXT: + data = reinterpret_cast<const char*>(sqlite3_value_text(argv[0])); + break; + default: + sqlite3_result_error(ctx, "snap() invalid type", -1); + sqlite3_result_error_code(ctx, SQLITE_MISMATCH); + return; + } + int size = sqlite3_value_bytes(argv[0]); + if (size <= 0) { + char result[] = {static_cast<char>(type)}; + sqlite3_result_blob(ctx, result, sizeof(result), SQLITE_TRANSIENT); + return; + } + size_t output_size = + snappy::MaxCompressedLength(static_cast<size_t>(size)) + 1; + if (output_size > + static_cast<size_t>(sqlite3_limit(sqlite3_context_db_handle(ctx), + SQLITE_LIMIT_LENGTH, -1))) { + sqlite3_result_error_toobig(ctx); + return; + } + auto output = + static_cast<char*>(sqlite3_malloc(static_cast<int>(output_size))); + if (output == nullptr) { + sqlite3_result_error_nomem(ctx); + return; + } + *output++ = static_cast<char>(type), --output_size; + snappy::RawCompress(data, static_cast<size_t>(size), output, &output_size); + sqlite3_result_blob(ctx, output - 1, static_cast<int>(output_size + 1), + sqlite3_free); +} + +static void unsnap(sqlite3_context* ctx, int /*argc*/, sqlite3_value** argv) { + int type = sqlite3_value_type(argv[0]); + switch (type) { + case SQLITE_NULL: + return; + case SQLITE_INTEGER: + sqlite3_result_int64(ctx, sqlite3_value_int64(argv[0])); + return; + case SQLITE_FLOAT: + sqlite3_result_double(ctx, sqlite3_value_double(argv[0])); + return; + case SQLITE_BLOB: + break; + default: + sqlite3_result_error(ctx, "unsnap() invalid type", -1); + sqlite3_result_error_code(ctx, SQLITE_MISMATCH); + return; + } + int size = sqlite3_value_bytes(argv[0]); + auto blob = reinterpret_cast<const char*>(sqlite3_value_blob(argv[0])); + if (size <= 0) { + sqlite3_result_zeroblob(ctx, 0); + return; + } + type = static_cast<int>(*blob++), --size; + if (type != SQLITE_BLOB && type != SQLITE_TEXT) { + sqlite3_result_error(ctx, "unsnap() first byte is invalid type", -1); + sqlite3_result_error_code(ctx, SQLITE_CORRUPT); + return; + } + if (size == 0) { + if (type == SQLITE_TEXT) { + sqlite3_result_text(ctx, "", 0, SQLITE_STATIC); + } else { + sqlite3_result_zeroblob(ctx, 0); + } + return; + } + size_t output_size; + if (!snappy::GetUncompressedLength(blob, static_cast<size_t>(size), + &output_size)) { + sqlite3_result_error(ctx, "snappy parse error", -1); + sqlite3_result_error_code(ctx, SQLITE_CORRUPT); + return; + } + if (output_size > + static_cast<size_t>(sqlite3_limit(sqlite3_context_db_handle(ctx), + SQLITE_LIMIT_LENGTH, -1))) { + sqlite3_result_error_toobig(ctx); + return; + } + auto output = + static_cast<char*>(sqlite3_malloc(static_cast<int>(output_size))); + if (output == nullptr) { + sqlite3_result_error_nomem(ctx); + return; + } + if (!snappy::RawUncompress(blob, static_cast<size_t>(size), output)) { + sqlite3_result_error(ctx, "snappy message corruption", -1); + sqlite3_result_error_code(ctx, SQLITE_CORRUPT); + sqlite3_free(output); + return; + } + if (type == SQLITE_TEXT) { + sqlite3_result_text(ctx, output, static_cast<int>(output_size), + sqlite3_free); + } else { + sqlite3_result_blob(ctx, output, static_cast<int>(output_size), + sqlite3_free); + } +} + +extern "C" { + +#ifndef SQLITE_DETERMINISTIC +#define SQLITE_DETERMINISTIC 0 +#endif + +#ifndef SQLITE_CALLBACK +#define SQLITE_CALLBACK +#endif + +SQLITE_CALLBACK int sqlite3_snapfn_init(sqlite3* db, const char** /*pzErrMsg*/, + const sqlite3_api_routines* pApi) { + SQLITE_EXTENSION_INIT2(pApi); + int rc; + + rc = sqlite3_create_function_v2( + db, + "snap", // zFunctionName + 1, // nArg + SQLITE_UTF8 | SQLITE_DETERMINISTIC, // eTextRep + nullptr, // pApp + snap, // xFunc + nullptr, // xStep + nullptr, // xFinal + nullptr // xDestroy + ); + if (rc != SQLITE_OK) { + return rc; + } + + rc = sqlite3_create_function_v2( + db, + "unsnap", // zFunctionName + 1, // nArg + SQLITE_UTF8 | SQLITE_DETERMINISTIC, // eTextRep + nullptr, // pApp + unsnap, // xFunc + nullptr, // xStep + nullptr, // xFinal + nullptr // xDestroy + ); + if (rc != SQLITE_OK) { + return rc; + } + + return SQLITE_OK; +} + +} // extern "C" diff --git a/tensorflow/core/lib/db/sqlite.cc b/tensorflow/core/lib/db/sqlite.cc index 23361e6431..b0a9e2f0d8 100644 --- a/tensorflow/core/lib/db/sqlite.cc +++ b/tensorflow/core/lib/db/sqlite.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/core/lib/io/record_reader.h" #include "tensorflow/core/platform/logging.h" +extern "C" int sqlite3_snapfn_init(sqlite3*, const char**, const void*); + namespace tensorflow { namespace { @@ -42,6 +44,7 @@ string ExecuteOrEmpty(Sqlite* db, const char* sql) { xla::StatusOr<std::shared_ptr<Sqlite>> Sqlite::Open(const string& uri) { sqlite3* sqlite = nullptr; TF_RETURN_IF_ERROR(MakeStatus(sqlite3_open(uri.c_str(), &sqlite))); + CHECK_EQ(SQLITE_OK, sqlite3_snapfn_init(sqlite, nullptr, nullptr)); Sqlite* db = new Sqlite(sqlite, uri); // This is the SQLite default since 2016. However it's good to set // this anyway, since we might get linked against an older version of diff --git a/tensorflow/core/lib/db/sqlite_test.cc b/tensorflow/core/lib/db/sqlite_test.cc index ba045274ad..29772b88ea 100644 --- a/tensorflow/core/lib/db/sqlite_test.cc +++ b/tensorflow/core/lib/db/sqlite_test.cc @@ -231,5 +231,22 @@ TEST_F(SqliteTest, BindFailed) { s.status().error_message().find("INSERT INTO T (a) VALUES (123)")); } +TEST_F(SqliteTest, SnappyExtension) { + auto stmt = db_->Prepare("SELECT UNSNAP(SNAP(?))"); + stmt.BindText(1, "hello"); + TF_ASSERT_OK(stmt.Step(&is_done_)); + EXPECT_FALSE(is_done_); + EXPECT_EQ("hello", stmt.ColumnString(0)); +} + +TEST_F(SqliteTest, SnappyBinaryCompatibility) { + auto stmt = db_->Prepare( + "SELECT UNSNAP(X'03207C746F6461792069732074686520656E64206F66207468652" + "072657075626C6963')"); + TF_ASSERT_OK(stmt.Step(&is_done_)); + EXPECT_FALSE(is_done_); + EXPECT_EQ("today is the end of the republic", stmt.ColumnString(0)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 969f8cbe1f..b295e2bc1e 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -219,7 +219,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ) tf_http_archive( - name = "sqlite_archive", + name = "org_sqlite", urls = [ "https://mirror.bazel.build/www.sqlite.org/2017/sqlite-amalgamation-3200000.zip", "http://www.sqlite.org/2017/sqlite-amalgamation-3200000.zip", diff --git a/third_party/sqlite.BUILD b/third_party/sqlite.BUILD index 9840d7b151..761838d194 100644 --- a/third_party/sqlite.BUILD +++ b/third_party/sqlite.BUILD @@ -1,16 +1,63 @@ # Description: -# Sqlite3 library. Provides utilities for interacting -# with sqlite3 databases. +# sqlite3 is a serverless SQL RDBMS. licenses(["unencumbered"]) # Public Domain -# exports_files(["LICENSE"]) +SQLITE_COPTS = [ + "-DHAVE_DECL_STRERROR_R=1", + "-DHAVE_STDINT_H=1", + "-DHAVE_INTTYPES_H=1", + "-D_FILE_OFFSET_BITS=64", + "-D_REENTRANT=1", +] + select({ + "@org_tensorflow//tensorflow:windows": [ + "-DSQLITE_MAX_TRIGGER_DEPTH=100", + ], + "@org_tensorflow//tensorflow:windows_msvc": [ + "-DSQLITE_MAX_TRIGGER_DEPTH=100", + ], + "@org_tensorflow//tensorflow:darwin": [ + "-DHAVE_GMTIME_R=1", + "-DHAVE_LOCALTIME_R=1", + "-DHAVE_USLEEP=1", + ], + "//conditions:default": [ + "-DHAVE_FDATASYNC=1", + "-DHAVE_GMTIME_R=1", + "-DHAVE_LOCALTIME_R=1", + "-DHAVE_POSIX_FALLOCATE=1", + "-DHAVE_USLEEP=1", + ], +}) +# Production build of SQLite library that's baked into TensorFlow. cc_library( - name = "sqlite", + name = "org_sqlite", srcs = ["sqlite3.c"], - hdrs = ["sqlite3.h"], - includes = ["."], - linkopts = ["-lm"], + hdrs = [ + "sqlite3.h", + "sqlite3ext.h", + ], + copts = SQLITE_COPTS, + defines = [ + # This gets rid of the bloat of deprecated functionality. It + # needs to be listed here instead of copts because it's actually + # referenced in the sqlite3.h file. + "SQLITE_OMIT_DEPRECATED", + ], + linkopts = select({ + "@org_tensorflow//tensorflow:windows_msvc": [], + "//conditions:default": [ + "-ldl", + "-lpthread", + ], + }), + visibility = ["//visibility:public"], +) + +# This is a Copybara sync helper for Google. +py_library( + name = "python", + srcs_version = "PY2AND3", visibility = ["//visibility:public"], ) |