author    | Justine Tunney <jart@google.com> | 2018-01-11 16:08:50 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-11 16:12:46 -0800
commit    | febdd26ae594133d24f82544706b1e012a5cf1ea (patch)
tree      | dd325008019ab10ce35f98368bf392ce4a118ec9 /tensorflow/contrib/tensorboard
parent    | fc252eb976c98c95a625ea6e6a0486334d3c5b6e (diff)
Add reservoir sampling to DB summary writer
This thing is kind of cool. It's able to turn a 350MB event log into a
35MB SQLite file at 80MB/s with one MacBook core. Best of all, this was
accomplished using a normalized schema without the embedded protos.
PiperOrigin-RevId: 181676380
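The size reduction comes from reservoir sampling each series of tensors, so a fixed number of rows per tag bounds output size no matter how long training runs. Below is a minimal, self-contained sketch of the Algorithm R decision logic this relies on; the `Reservoir` class and `Add` method are illustrative names for exposition, not code from this commit, which applies the same decision inside its series writer.

```cpp
#include <cstddef>
#include <random>
#include <vector>

// Sketch of Algorithm R reservoir sampling: the first `slots` samples
// always land in the reservoir; afterwards sample i replaces a random
// slot with probability slots/(i+1), so writes become progressively
// rarer as training continues.
template <typename T>
class Reservoir {
 public:
  explicit Reservoir(size_t slots) : slots_(slots) {}

  // Returns true if the sample was kept (i.e. a write would happen).
  bool Add(const T& sample) {
    size_t i = count_++;
    if (i < slots_) {
      data_.push_back(sample);      // reservoir still filling up
      return true;
    }
    size_t j = rng_() % (i + 1);    // uniform in [0, i]
    if (j < slots_) {
      data_[j] = sample;            // replace a random existing slot
      return true;
    }
    return false;                   // drop: no disk write needed
  }

 private:
  const size_t slots_;
  size_t count_ = 0;
  std::mt19937_64 rng_{std::mt19937_64::default_seed};
  std::vector<T> data_;
};
```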
Diffstat (limited to 'tensorflow/contrib/tensorboard')
-rw-r--r-- | tensorflow/contrib/tensorboard/db/BUILD                     |    6
-rw-r--r-- | tensorflow/contrib/tensorboard/db/schema.cc                 |  239
-rw-r--r-- | tensorflow/contrib/tensorboard/db/summary_db_writer.cc      | 1206
-rw-r--r-- | tensorflow/contrib/tensorboard/db/summary_db_writer.h       |    7
-rw-r--r-- | tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc |   71
5 files changed, 1058 insertions, 471 deletions
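For orientation before the full diff: the reworked schema keys each sample row by (series, step), where series is a tag's Permanent ID, and stores dtype, shape, and raw data in place of embedded protos. A hypothetical read-side query against that table might look like the sketch below; it is an assumption for illustration, not part of this change, though the table and column names come from the new Tensors table in schema.cc.

```cpp
// Hypothetical query: fetch the reservoir-sampled points for one tag,
// oldest step first. The tag's Permanent ID is bound to the parameter.
constexpr const char* kReadSeriesSql = R"sql(
  SELECT step, computed_time, dtype, shape, data
  FROM Tensors
  WHERE series = ?
    AND step IS NOT NULL
  ORDER BY step
)sql";
```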
diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD index 3a3402c59b..4c9cc4ccd6 100644 --- a/tensorflow/contrib/tensorboard/db/BUILD +++ b/tensorflow/contrib/tensorboard/db/BUILD @@ -5,12 +5,13 @@ package(default_visibility = ["//tensorflow:internal"]) licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_copts") cc_library( name = "schema", srcs = ["schema.cc"], hdrs = ["schema.h"], + copts = tf_copts(), deps = [ "//tensorflow/core:lib", "//tensorflow/core/lib/db:sqlite", @@ -32,8 +33,10 @@ cc_library( name = "summary_db_writer", srcs = ["summary_db_writer.cc"], hdrs = ["summary_db_writer.h"], + copts = tf_copts(), deps = [ ":schema", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", @@ -47,6 +50,7 @@ tf_cc_test( size = "small", srcs = ["summary_db_writer_test.cc"], deps = [ + ":schema", ":summary_db_writer", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/contrib/tensorboard/db/schema.cc b/tensorflow/contrib/tensorboard/db/schema.cc index 2cd00876f8..6ccd386dc0 100644 --- a/tensorflow/contrib/tensorboard/db/schema.cc +++ b/tensorflow/contrib/tensorboard/db/schema.cc @@ -22,8 +22,7 @@ namespace { Status Run(Sqlite* db, const char* sql) { SqliteStatement stmt; TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt)); - TF_RETURN_IF_ERROR(stmt.StepAndReset()); - return Status::OK(); + return stmt.StepAndReset(); } } // namespace @@ -38,37 +37,34 @@ Status SetupTensorboardSqliteDb(Sqlite* db) { db->PrepareOrDie("PRAGMA user_version=0").StepAndResetOrDie(); Status s; - // Creates Ids table. + // Ids identify resources. // - // This table must be used to randomly allocate Permanent IDs for - // all top-level tables, in order to maintain an invariant where - // foo_id != bar_id for all IDs of any two tables. + // This table can be used to efficiently generate Permanent IDs in + // conjunction with a random number generator. Unlike rowids these + // IDs safe to use in URLs and unique across tables. // - // A row should only be deleted from this table if it can be - // guaranteed that it exists absolutely nowhere else in the entire - // system. + // Within any given system, there can't be any foo_id == bar_id for + // all rows of any two (Foos, Bars) tables. A row should only be + // deleted from this table if there's a very high level of confidence + // it exists nowhere else in the system. // // Fields: - // id: An ID that was allocated globally. This must be in the - // range [1,2**47). 0 is assigned the same meaning as NULL and - // shouldn't be stored; 2**63-1 is reserved for statically - // allocating space in a page to UPDATE later; and all other - // int64 values are reserved for future use. + // id: The system-wide ID. This must be in the range [1,2**47). 0 + // is assigned the same meaning as NULL and shouldn't be stored + // and all other int64 values are reserved for future use. Please + // note that id is also the rowid. s.Update(Run(db, R"sql( CREATE TABLE IF NOT EXISTS Ids ( id INTEGER PRIMARY KEY ) )sql")); - // Creates Descriptions table. - // - // This table allows TensorBoard to associate Markdown text with any - // object in the database that has a Permanent ID. + // Descriptions are Markdown text that can be associated with any + // resource that has a Permanent ID. // // Fields: - // id: The Permanent ID of the associated object. 
This is also the - // SQLite rowid. - // description: Arbitrary Markdown text. + // id: The foo_id of the associated row in Foos. + // description: Arbitrary NUL-terminated Markdown text. s.Update(Run(db, R"sql( CREATE TABLE IF NOT EXISTS Descriptions ( id INTEGER PRIMARY KEY, @@ -76,121 +72,136 @@ Status SetupTensorboardSqliteDb(Sqlite* db) { ) )sql")); - // Creates Tensors table. + // Tensors are 0..n-dimensional numbers or strings. // // Fields: - // rowid: Ephemeral b-tree ID dictating locality. - // tag_id: ID of associated Tag. + // rowid: Ephemeral b-tree ID. + // series: The Permanent ID of a different resource, e.g. tag_id. A + // tensor will be vacuumed if no series == foo_id exists for all + // rows of all Foos. When series is NULL this tensor may serve + // undefined purposes. This field should be set on placeholders. + // step: Arbitrary number to uniquely order tensors within series. + // The meaning of step is undefined when series is NULL. This may + // be set on placeholders to prepopulate index pages. // computed_time: Float UNIX timestamp with microsecond precision. // In the old summaries system that uses FileWriter, this is the // wall time around when tf.Session.run finished. In the new // summaries system, it is the wall time of when the tensor was // computed. On systems with monotonic clocks, it is calculated // by adding the monotonic run duration to Run.started_time. - // This field is not indexed because, in practice, it should be - // ordered the same or nearly the same as TensorIndex, so local - // insertion sort might be more suitable. - // step: User-supplied number, ordering this tensor in Tag. - // If NULL then the Tag must have only one Tensor. - // tensor: Can be an INTEGER (DT_INT64), FLOAT (DT_DOUBLE), or - // BLOB. The structure of a BLOB is currently undefined, but in - // essence it is a Snappy tf.TensorProto that spills over into - // TensorChunks. + // dtype: The tensorflow::DataType ID. For example, DT_INT64 is 9. + // When NULL or 0 this must be treated as a placeholder row that + // does not officially exist. + // shape: A comma-delimited list of int64 >=0 values representing + // length of each dimension in the tensor. This must be a valid + // shape. That means no -1 values and, in the case of numeric + // tensors, length(data) == product(shape) * sizeof(dtype). Empty + // means this is a scalar a.k.a. 0-dimensional tensor. + // data: Little-endian raw tensor memory. If dtype is DT_STRING and + // shape is empty, the nullness of this field indicates whether or + // not it contains the tensor contents; otherwise TensorStrings + // must be queried. If dtype is NULL then ZEROBLOB can be used on + // this field to reserve row space to be updated later. s.Update(Run(db, R"sql( CREATE TABLE IF NOT EXISTS Tensors ( rowid INTEGER PRIMARY KEY, - tag_id INTEGER NOT NULL, - computed_time REAL, + series INTEGER, step INTEGER, - tensor BLOB + dtype INTEGER, + computed_time REAL, + shape TEXT, + data BLOB ) )sql")); - // Uniquely indexes (tag_id, step) on Tensors table. s.Update(Run(db, R"sql( - CREATE UNIQUE INDEX IF NOT EXISTS TensorIndex - ON Tensors (tag_id, step) + CREATE UNIQUE INDEX IF NOT EXISTS + TensorSeriesStepIndex + ON + Tensors (series, step) + WHERE + series IS NOT NULL + AND step IS NOT NULL )sql")); - // Creates TensorChunks table. + // TensorStrings are the flat contents of 1..n dimensional DT_STRING + // Tensors. 
// - // This table can be used to split up a tensor across many rows, - // which has the advantage of not slowing down table scans on the - // main table, allowing asynchronous fetching, minimizing copying, - // and preventing large buffers from being allocated. + // The number of rows associated with a Tensor must be equal to the + // product of its Tensors.shape. // // Fields: - // rowid: Ephemeral b-tree ID dictating locality. - // tag_id: ID of associated Tag. - // step: Same as corresponding Tensors.step. - // sequence: 1-indexed sequence number for ordering chunks. Please - // note that the 0th index is Tensors.tensor. - // chunk: Bytes of next chunk in tensor. + // rowid: Ephemeral b-tree ID. + // tensor_rowid: References Tensors.rowid. + // idx: Index in flattened tensor, starting at 0. + // data: The string value at a particular index. NUL characters are + // permitted. s.Update(Run(db, R"sql( - CREATE TABLE IF NOT EXISTS TensorChunks ( + CREATE TABLE IF NOT EXISTS TensorStrings ( rowid INTEGER PRIMARY KEY, - tag_id INTEGER NOT NULL, - step INTEGER, - sequence INTEGER, - chunk BLOB + tensor_rowid INTEGER NOT NULL, + idx INTEGER NOT NULL, + data BLOB ) )sql")); - // Uniquely indexes (tag_id, step, sequence) on TensorChunks table. s.Update(Run(db, R"sql( - CREATE UNIQUE INDEX IF NOT EXISTS TensorChunkIndex - ON TensorChunks (tag_id, step, sequence) + CREATE UNIQUE INDEX IF NOT EXISTS TensorStringIndex + ON TensorStrings (tensor_rowid, idx) )sql")); - // Creates Tags table. + // Tags are series of Tensors. // // Fields: - // rowid: Ephemeral b-tree ID dictating locality. + // rowid: Ephemeral b-tree ID. // tag_id: The Permanent ID of the Tag. // run_id: Optional ID of associated Run. - // tag_name: The tag field in summary.proto, unique across Run. // inserted_time: Float UNIX timestamp with µs precision. This is // always the wall time of when the row was inserted into the // DB. It may be used as a hint for an archival job. + // tag_name: The tag field in summary.proto, unique across Run. // display_name: Optional for GUI and defaults to tag_name. // plugin_name: Arbitrary TensorBoard plugin name for dispatch. // plugin_data: Arbitrary data that plugin wants. + // + // TODO(jart): Maybe there should be a Plugins table? s.Update(Run(db, R"sql( CREATE TABLE IF NOT EXISTS Tags ( rowid INTEGER PRIMARY KEY, run_id INTEGER, tag_id INTEGER NOT NULL, - tag_name TEXT, inserted_time DOUBLE, + tag_name TEXT, display_name TEXT, plugin_name TEXT, plugin_data BLOB ) )sql")); - // Uniquely indexes tag_id on Tags table. s.Update(Run(db, R"sql( CREATE UNIQUE INDEX IF NOT EXISTS TagIdIndex ON Tags (tag_id) )sql")); - // Uniquely indexes (run_id, tag_name) on Tags table. s.Update(Run(db, R"sql( - CREATE UNIQUE INDEX IF NOT EXISTS TagNameIndex - ON Tags (run_id, tag_name) - WHERE tag_name IS NOT NULL + CREATE UNIQUE INDEX IF NOT EXISTS + TagRunNameIndex + ON + Tags (run_id, tag_name) + WHERE + run_id IS NOT NULL + AND tag_name IS NOT NULL )sql")); - // Creates Runs table. + // Runs are groups of Tags. // - // This table stores information about Runs. Each row usually - // represents a single attempt at training or testing a TensorFlow - // model, with a given set of hyper-parameters, whose summaries are - // written out to a single event logs directory with a monotonic step - // counter. 
+ // Each Run usually represents a single attempt at training or testing + // a TensorFlow model, with a given set of hyper-parameters, whose + // summaries are written out to a single event logs directory with a + // monotonic step counter. // // Fields: - // rowid: Ephemeral b-tree ID dictating locality. + // rowid: Ephemeral b-tree ID. // run_id: The Permanent ID of the Run. This has a 1:1 mapping // with a SummaryWriter instance. If two writers spawn for a // given (user_name, run_name, run_name) then each should @@ -199,8 +210,8 @@ Status SetupTensorboardSqliteDb(Sqlite* db) { // previous invocations will then enter limbo, where they may be // accessible for certain operations, but should be garbage // collected eventually. - // experiment_id: Optional ID of associated Experiment. // run_name: User-supplied string, unique across Experiment. + // experiment_id: Optional ID of associated Experiment. // inserted_time: Float UNIX timestamp with µs precision. This is // always the time the row was inserted into the database. It // does not change. @@ -215,40 +226,33 @@ Status SetupTensorboardSqliteDb(Sqlite* db) { // SummaryWriter resource that created this run was destroyed. // Once this value becomes non-NULL a Run and its Tags and // Tensors should be regarded as immutable. - // graph_id: ID of associated Graphs row. s.Update(Run(db, R"sql( CREATE TABLE IF NOT EXISTS Runs ( rowid INTEGER PRIMARY KEY, experiment_id INTEGER, run_id INTEGER NOT NULL, - run_name TEXT, inserted_time REAL, started_time REAL, finished_time REAL, - graph_id INTEGER + run_name TEXT ) )sql")); - // Uniquely indexes run_id on Runs table. s.Update(Run(db, R"sql( CREATE UNIQUE INDEX IF NOT EXISTS RunIdIndex ON Runs (run_id) )sql")); - // Uniquely indexes (experiment_id, run_name) on Runs table. s.Update(Run(db, R"sql( CREATE UNIQUE INDEX IF NOT EXISTS RunNameIndex ON Runs (experiment_id, run_name) WHERE run_name IS NOT NULL )sql")); - // Creates Experiments table. - // - // This table stores information about experiments, which are sets of - // runs. + // Experiments are groups of Runs. // // Fields: - // rowid: Ephemeral b-tree ID dictating locality. + // rowid: Ephemeral b-tree ID. // user_id: Optional ID of associated User. // experiment_id: The Permanent ID of the Experiment. // experiment_name: User-supplied string, unique across User. @@ -259,34 +263,39 @@ Status SetupTensorboardSqliteDb(Sqlite* db) { // the MIN(experiment.started_time, run.started_time) of each // Run added to the database, including Runs which have since // been overwritten. + // is_watching: A boolean indicating if someone is actively + // looking at this Experiment in the TensorBoard GUI. Tensor + // writers that do reservoir sampling can query this value to + // decide if they want the "keep last" behavior. This improves + // the performance of long running training while allowing low + // latency feedback in TensorBoard. s.Update(Run(db, R"sql( CREATE TABLE IF NOT EXISTS Experiments ( rowid INTEGER PRIMARY KEY, user_id INTEGER, experiment_id INTEGER NOT NULL, - experiment_name TEXT, inserted_time REAL, - started_time REAL + started_time REAL, + is_watching INTEGER, + experiment_name TEXT ) )sql")); - // Uniquely indexes experiment_id on Experiments table. s.Update(Run(db, R"sql( CREATE UNIQUE INDEX IF NOT EXISTS ExperimentIdIndex ON Experiments (experiment_id) )sql")); - // Uniquely indexes (user_id, experiment_name) on Experiments table. 
s.Update(Run(db, R"sql( CREATE UNIQUE INDEX IF NOT EXISTS ExperimentNameIndex ON Experiments (user_id, experiment_name) WHERE experiment_name IS NOT NULL )sql")); - // Creates Users table. + // Users are people who love TensorBoard. // // Fields: - // rowid: Ephemeral b-tree ID dictating locality. + // rowid: Ephemeral b-tree ID. // user_id: The Permanent ID of the User. // user_name: Unique user name. // email: Optional unique email address. @@ -297,61 +306,66 @@ Status SetupTensorboardSqliteDb(Sqlite* db) { CREATE TABLE IF NOT EXISTS Users ( rowid INTEGER PRIMARY KEY, user_id INTEGER NOT NULL, + inserted_time REAL, user_name TEXT, - email TEXT, - inserted_time REAL + email TEXT ) )sql")); - // Uniquely indexes user_id on Users table. s.Update(Run(db, R"sql( CREATE UNIQUE INDEX IF NOT EXISTS UserIdIndex ON Users (user_id) )sql")); - // Uniquely indexes user_name on Users table. s.Update(Run(db, R"sql( CREATE UNIQUE INDEX IF NOT EXISTS UserNameIndex ON Users (user_name) WHERE user_name IS NOT NULL )sql")); - // Uniquely indexes email on Users table. s.Update(Run(db, R"sql( CREATE UNIQUE INDEX IF NOT EXISTS UserEmailIndex ON Users (email) WHERE email IS NOT NULL )sql")); - // Creates Graphs table. + // Graphs define how Tensors flowed in Runs. // // Fields: - // rowid: Ephemeral b-tree ID dictating locality. + // rowid: Ephemeral b-tree ID. + // run_id: The Permanent ID of the associated Run. Only one Graph + // can be associated with a Run. // graph_id: The Permanent ID of the Graph. // inserted_time: Float UNIX timestamp with µs precision. This is // always the wall time of when the row was inserted into the // DB. It may be used as a hint for an archival job. - // node_def: Contains Snappy tf.GraphDef proto. All fields will be - // cleared except those not expressed in SQL. + // node_def: Contains tf.GraphDef proto. All fields will be cleared + // except those not expressed in SQL. s.Update(Run(db, R"sql( CREATE TABLE IF NOT EXISTS Graphs ( rowid INTEGER PRIMARY KEY, + run_id INTEGER, graph_id INTEGER NOT NULL, inserted_time REAL, graph_def BLOB ) )sql")); - // Uniquely indexes graph_id on Graphs table. s.Update(Run(db, R"sql( CREATE UNIQUE INDEX IF NOT EXISTS GraphIdIndex ON Graphs (graph_id) )sql")); - // Creates Nodes table. + s.Update(Run(db, R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS GraphRunIndex + ON Graphs (run_id) + WHERE run_id IS NOT NULL + )sql")); + + // Nodes are the vertices in Graphs. // // Fields: - // rowid: Ephemeral b-tree ID dictating locality. + // rowid: Ephemeral b-tree ID. // graph_id: The Permanent ID of the associated Graph. // node_id: ID for this node. This is more like a 0-index within // the Graph. Please note indexes are allowed to be removed. @@ -361,8 +375,10 @@ Status SetupTensorboardSqliteDb(Sqlite* db) { // node_def.name proto field must not be cleared. // op: Copied from tf.NodeDef proto. // device: Copied from tf.NodeDef proto. - // node_def: Contains Snappy tf.NodeDef proto. All fields will be - // cleared except those not expressed in SQL. + // node_def: Contains tf.NodeDef proto. All fields will be cleared + // except those not expressed in SQL. + // + // TODO(jart): Make separate tables for op and device strings. s.Update(Run(db, R"sql( CREATE TABLE IF NOT EXISTS Nodes ( rowid INTEGER PRIMARY KEY, @@ -375,32 +391,35 @@ Status SetupTensorboardSqliteDb(Sqlite* db) { ) )sql")); - // Uniquely indexes (graph_id, node_id) on Nodes table. 
s.Update(Run(db, R"sql( CREATE UNIQUE INDEX IF NOT EXISTS NodeIdIndex ON Nodes (graph_id, node_id) )sql")); - // Uniquely indexes (graph_id, node_name) on Nodes table. s.Update(Run(db, R"sql( CREATE UNIQUE INDEX IF NOT EXISTS NodeNameIndex ON Nodes (graph_id, node_name) WHERE node_name IS NOT NULL )sql")); - // Creates NodeInputs table. + // NodeInputs are directed edges between Nodes in Graphs. // // Fields: - // rowid: Ephemeral b-tree ID dictating locality. + // rowid: Ephemeral b-tree ID. // graph_id: The Permanent ID of the associated Graph. // node_id: Index of Node in question. This can be considered the // 'to' vertex. // idx: Used for ordering inputs on a given Node. // input_node_id: Nodes.node_id of the corresponding input node. // This can be considered the 'from' vertex. + // input_node_idx: Since a Node can output multiple Tensors, this + // is the integer index of which of those outputs is our input. + // NULL is treated as 0. // is_control: If non-zero, indicates this input is a controlled // dependency, which means this isn't an edge through which // tensors flow. NULL means 0. + // + // TODO(jart): Rename to NodeEdges. s.Update(Run(db, R"sql( CREATE TABLE IF NOT EXISTS NodeInputs ( rowid INTEGER PRIMARY KEY, @@ -408,11 +427,11 @@ Status SetupTensorboardSqliteDb(Sqlite* db) { node_id INTEGER NOT NULL, idx INTEGER NOT NULL, input_node_id INTEGER NOT NULL, + input_node_idx INTEGER, is_control INTEGER ) )sql")); - // Uniquely indexes (graph_id, node_id, idx) on NodeInputs table. s.Update(Run(db, R"sql( CREATE UNIQUE INDEX IF NOT EXISTS NodeInputsIndex ON NodeInputs (graph_id, node_id, idx) diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index 44887930c1..889ac43415 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -14,17 +14,37 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" -#include "tensorflow/contrib/tensorboard/db/schema.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/db/sqlite.h" #include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/util/event.pb.h" +// TODO(jart): Break this up into multiple files with excellent unit tests. +// TODO(jart): Make decision to write in separate op. +// TODO(jart): Add really good busy handling. + +// clang-format off +#define CALL_SUPPORTED_TYPES(m) \ + TF_CALL_string(m) \ + TF_CALL_half(m) \ + TF_CALL_float(m) \ + TF_CALL_double(m) \ + TF_CALL_complex64(m) \ + TF_CALL_complex128(m) \ + TF_CALL_int8(m) \ + TF_CALL_int16(m) \ + TF_CALL_int32(m) \ + TF_CALL_int64(m) \ + TF_CALL_uint8(m) \ + TF_CALL_uint16(m) \ + TF_CALL_uint32(m) \ + TF_CALL_uint64(m) +// clang-format on + namespace tensorflow { namespace { @@ -33,115 +53,145 @@ const uint64 kIdTiers[] = { 0x7fffffULL, // 23-bit (3 bytes on disk) 0x7fffffffULL, // 31-bit (4 bytes on disk) 0x7fffffffffffULL, // 47-bit (5 bytes on disk) - // Remaining bits reserved for future use. 
+ // remaining bits for future use }; const int kMaxIdTier = sizeof(kIdTiers) / sizeof(uint64); const int kIdCollisionDelayMicros = 10; const int kMaxIdCollisions = 21; // sum(2**i*10µs for i in range(21))~=21s const int64 kAbsent = 0LL; -const int64 kReserved = 0x7fffffffffffffffLL; -double GetWallTime(Env* env) { +const char* kScalarPluginName = "scalars"; +const char* kImagePluginName = "images"; +const char* kAudioPluginName = "audio"; +const char* kHistogramPluginName = "histograms"; + +const int kScalarSlots = 10000; +const int kImageSlots = 10; +const int kAudioSlots = 10; +const int kHistogramSlots = 1; +const int kTensorSlots = 10; + +const int64 kReserveMinBytes = 32; +const double kReserveMultiplier = 1.5; + +// Flush is a misnomer because what we're actually doing is having lots +// of commits inside any SqliteTransaction that writes potentially +// hundreds of megs but doesn't need the transaction to maintain its +// invariants. This ensures the WAL read penalty is small and might +// allow writers in other processes a chance to schedule. +const uint64 kFlushBytes = 1024 * 1024; + +double DoubleTime(uint64 micros) { // TODO(@jart): Follow precise definitions for time laid out in schema. // TODO(@jart): Use monotonic clock from gRPC codebase. - return static_cast<double>(env->NowMicros()) / 1.0e6; + return static_cast<double>(micros) / 1.0e6; } -Status Serialize(const protobuf::MessageLite& proto, string* output) { - output->clear(); - if (!proto.SerializeToString(output)) { - return errors::DataLoss("SerializeToString failed"); +string StringifyShape(const TensorShape& shape) { + string result; + bool first = true; + for (const auto& dim : shape) { + if (first) { + first = false; + } else { + strings::StrAppend(&result, ","); + } + strings::StrAppend(&result, dim.size); } - return Status::OK(); + return result; } -Status BindProto(SqliteStatement* stmt, int parameter, - const protobuf::MessageLite& proto) { - string serialized; - TF_RETURN_IF_ERROR(Serialize(proto, &serialized)); - stmt->BindBlob(parameter, serialized); +Status CheckSupportedType(const Tensor& t) { +#define CASE(T) \ + case DataTypeToEnum<T>::value: \ + break; + switch (t.dtype()) { + CALL_SUPPORTED_TYPES(CASE) + default: + return errors::Unimplemented(DataTypeString(t.dtype()), + " tensors unsupported on platform"); + } return Status::OK(); +#undef CASE } -Status BindTensor(SqliteStatement* stmt, int parameter, const Tensor& t) { - // TODO(@jart): Make portable between little and big endian systems. - // TODO(@jart): Use TensorChunks with minimal copying for big tensors. - // TODO(@jart): Add field to indicate encoding. - TensorProto p; - t.AsProtoTensorContent(&p); - return BindProto(stmt, parameter, p); -} - -// Tries to fudge shape and dtype to something with smaller storage. 
-Status CoerceScalar(const Tensor& t, Tensor* out) { +Tensor AsScalar(const Tensor& t) { + Tensor t2{t.dtype(), {}}; +#define CASE(T) \ + case DataTypeToEnum<T>::value: \ + t2.scalar<T>()() = t.flat<T>()(0); \ + break; switch (t.dtype()) { - case DT_DOUBLE: - *out = t; - break; - case DT_INT64: - *out = t; - break; - case DT_FLOAT: - *out = {DT_DOUBLE, {}}; - out->scalar<double>()() = t.scalar<float>()(); - break; - case DT_HALF: - *out = {DT_DOUBLE, {}}; - out->scalar<double>()() = static_cast<double>(t.scalar<Eigen::half>()()); - break; - case DT_INT32: - *out = {DT_INT64, {}}; - out->scalar<int64>()() = t.scalar<int32>()(); - break; - case DT_INT16: - *out = {DT_INT64, {}}; - out->scalar<int64>()() = t.scalar<int16>()(); - break; - case DT_INT8: - *out = {DT_INT64, {}}; - out->scalar<int64>()() = t.scalar<int8>()(); - break; - case DT_UINT32: - *out = {DT_INT64, {}}; - out->scalar<int64>()() = t.scalar<uint32>()(); - break; - case DT_UINT16: - *out = {DT_INT64, {}}; - out->scalar<int64>()() = t.scalar<uint16>()(); - break; - case DT_UINT8: - *out = {DT_INT64, {}}; - out->scalar<int64>()() = t.scalar<uint8>()(); - break; + CALL_SUPPORTED_TYPES(CASE) default: - return errors::Unimplemented("Scalar summary for dtype ", - DataTypeString(t.dtype()), - " is not supported."); + t2 = {DT_FLOAT, {}}; + t2.scalar<float>()() = NAN; + break; } - return Status::OK(); + return t2; +#undef CASE +} + +void PatchPluginName(SummaryMetadata* metadata, const char* name) { + if (metadata->plugin_data().plugin_name().empty()) { + metadata->mutable_plugin_data()->set_plugin_name(name); + } +} + +int GetSlots(const Tensor& t, const SummaryMetadata& metadata) { + if (metadata.plugin_data().plugin_name() == kScalarPluginName) { + return kScalarSlots; + } else if (metadata.plugin_data().plugin_name() == kImagePluginName) { + return kImageSlots; + } else if (metadata.plugin_data().plugin_name() == kAudioPluginName) { + return kAudioSlots; + } else if (metadata.plugin_data().plugin_name() == kHistogramPluginName) { + return kHistogramSlots; + } else if (t.dims() == 0 && t.dtype() != DT_STRING) { + return kScalarSlots; + } else { + return kTensorSlots; + } +} + +Status SetDescription(Sqlite* db, int64 id, const StringPiece& markdown) { + const char* sql = R"sql( + INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?) + )sql"; + SqliteStatement insert_desc; + TF_RETURN_IF_ERROR(db->Prepare(sql, &insert_desc)); + insert_desc.BindInt(1, id); + insert_desc.BindText(2, markdown); + return insert_desc.StepAndReset(); } -/// \brief Generates unique IDs randomly in the [1,2**63-2] range. +/// \brief Generates unique IDs randomly in the [1,2**63-1] range. /// /// This class starts off generating IDs in the [1,2**23-1] range, /// because it's human friendly and occupies 4 bytes max on disk with /// SQLite's zigzag varint encoding. Then, each time a collision /// happens, the random space is increased by 8 bits. /// -/// This class uses exponential back-off so writes will slow down as -/// the ID space becomes exhausted. +/// This class uses exponential back-off so writes gradually slow down +/// as IDs become exhausted but reads are still possible. +/// +/// This class is thread safe. 
class IdAllocator { public: - IdAllocator(Env* env, Sqlite* db) - : env_{env}, - inserter_{db->PrepareOrDie("INSERT INTO Ids (id) VALUES (?)")} {} + IdAllocator(Env* env, Sqlite* db) : env_{env}, db_{db} { + DCHECK(env_ != nullptr); + DCHECK(db_ != nullptr); + } - Status CreateNewId(int64* id) { + Status CreateNewId(int64* id) LOCKS_EXCLUDED(mu_) { + mutex_lock lock(mu_); Status s; + SqliteStatement stmt; + TF_RETURN_IF_ERROR(db_->Prepare("INSERT INTO Ids (id) VALUES (?)", &stmt)); for (int i = 0; i < kMaxIdCollisions; ++i) { int64 tid = MakeRandomId(); - inserter_.BindInt(1, tid); - s = inserter_.StepAndReset(); + stmt.BindInt(1, tid); + s = stmt.StepAndReset(); if (s.ok()) { *id = tid; break; @@ -167,34 +217,38 @@ class IdAllocator { } private: - int64 MakeRandomId() { + int64 MakeRandomId() EXCLUSIVE_LOCKS_REQUIRED(mu_) { int64 id = static_cast<int64>(random::New64() & kIdTiers[tier_]); if (id == kAbsent) ++id; - if (id == kReserved) --id; return id; } - Env* env_; - SqliteStatement inserter_; - int tier_ = 0; + mutex mu_; + Env* const env_; + Sqlite* const db_; + int tier_ GUARDED_BY(mu_) = 0; + + TF_DISALLOW_COPY_AND_ASSIGN(IdAllocator); }; -class GraphSaver { +class GraphWriter { public: - static Status Save(Env* env, Sqlite* db, IdAllocator* id_allocator, - GraphDef* graph, int64* graph_id) { - TF_RETURN_IF_ERROR(id_allocator->CreateNewId(graph_id)); - GraphSaver saver{env, db, graph, *graph_id}; + static Status Save(Sqlite* db, SqliteTransaction* txn, IdAllocator* ids, + GraphDef* graph, uint64 now, int64 run_id, int64* graph_id) + SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) { + TF_RETURN_IF_ERROR(ids->CreateNewId(graph_id)); + GraphWriter saver{db, txn, graph, now, *graph_id}; saver.MapNameToNodeId(); - TF_RETURN_IF_ERROR(saver.SaveNodeInputs()); - TF_RETURN_IF_ERROR(saver.SaveNodes()); - TF_RETURN_IF_ERROR(saver.SaveGraph()); + TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodeInputs(), "SaveNodeInputs"); + TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodes(), "SaveNodes"); + TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveGraph(run_id), "SaveGraph"); return Status::OK(); } private: - GraphSaver(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id) - : env_(env), db_(db), graph_(graph), graph_id_(graph_id) {} + GraphWriter(Sqlite* db, SqliteTransaction* txn, GraphDef* graph, uint64 now, + int64 graph_id) + : db_(db), txn_(txn), graph_(graph), now_(now), graph_id_(graph_id) {} void MapNameToNodeId() { size_t toto = static_cast<size_t>(graph_->node_size()); @@ -209,161 +263,193 @@ class GraphSaver { } Status SaveNodeInputs() { - auto insert = db_->PrepareOrDie(R"sql( - INSERT INTO NodeInputs (graph_id, node_id, idx, input_node_id, is_control) - VALUES (?, ?, ?, ?, ?) - )sql"); + const char* sql = R"sql( + INSERT INTO NodeInputs ( + graph_id, + node_id, + idx, + input_node_id, + input_node_idx, + is_control + ) VALUES (?, ?, ?, ?, ?, ?) 
+ )sql"; + SqliteStatement insert; + TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert)); for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { const NodeDef& node = graph_->node(node_id); for (int idx = 0; idx < node.input_size(); ++idx) { StringPiece name = node.input(idx); - insert.BindInt(1, graph_id_); - insert.BindInt(2, node_id); - insert.BindInt(3, idx); + int64 input_node_id; + int64 input_node_idx = 0; + int64 is_control = 0; + size_t i = name.rfind(':'); + if (i != StringPiece::npos) { + if (!strings::safe_strto64(name.substr(i + 1, name.size() - i - 1), + &input_node_idx)) { + return errors::DataLoss("Bad NodeDef.input: ", name); + } + name.remove_suffix(name.size() - i); + } if (!name.empty() && name[0] == '^') { name.remove_prefix(1); - insert.BindInt(5, 1); + is_control = 1; } auto e = name_to_node_id_.find(name); if (e == name_to_node_id_.end()) { return errors::DataLoss("Could not find node: ", name); } - insert.BindInt(4, e->second); + input_node_id = e->second; + insert.BindInt(1, graph_id_); + insert.BindInt(2, node_id); + insert.BindInt(3, idx); + insert.BindInt(4, input_node_id); + insert.BindInt(5, input_node_idx); + insert.BindInt(6, is_control); + unflushed_bytes_ += insert.size(); TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(), " -> ", name); + TF_RETURN_IF_ERROR(MaybeFlush()); } } return Status::OK(); } Status SaveNodes() { - auto insert = db_->PrepareOrDie(R"sql( - INSERT INTO Nodes (graph_id, node_id, node_name, op, device, node_def) - VALUES (?, ?, ?, ?, ?, snap(?)) - )sql"); + const char* sql = R"sql( + INSERT INTO Nodes ( + graph_id, + node_id, + node_name, + op, + device, + node_def) + VALUES (?, ?, ?, ?, ?, ?) + )sql"; + SqliteStatement insert; + TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert)); for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { NodeDef* node = graph_->mutable_node(node_id); insert.BindInt(1, graph_id_); insert.BindInt(2, node_id); insert.BindText(3, node->name()); + insert.BindText(4, node->op()); + insert.BindText(5, node->device()); node->clear_name(); - if (!node->op().empty()) { - insert.BindText(4, node->op()); - node->clear_op(); - } - if (!node->device().empty()) { - insert.BindText(5, node->device()); - node->clear_device(); - } + node->clear_op(); + node->clear_device(); node->clear_input(); - TF_RETURN_IF_ERROR(BindProto(&insert, 6, *node)); + string node_def; + if (node->SerializeToString(&node_def)) { + insert.BindBlobUnsafe(6, node_def); + } + unflushed_bytes_ += insert.size(); TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name()); + TF_RETURN_IF_ERROR(MaybeFlush()); } return Status::OK(); } - Status SaveGraph() { - auto insert = db_->PrepareOrDie(R"sql( - INSERT INTO Graphs (graph_id, inserted_time, graph_def) - VALUES (?, ?, snap(?)) - )sql"); - insert.BindInt(1, graph_id_); - insert.BindDouble(2, GetWallTime(env_)); + Status SaveGraph(int64 run_id) { + const char* sql = R"sql( + INSERT OR REPLACE INTO Graphs ( + run_id, + graph_id, + inserted_time, + graph_def + ) VALUES (?, ?, ?, ?) 
+ )sql"; + SqliteStatement insert; + TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert)); + if (run_id != kAbsent) insert.BindInt(1, run_id); + insert.BindInt(2, graph_id_); + insert.BindDouble(3, DoubleTime(now_)); graph_->clear_node(); - TF_RETURN_IF_ERROR(BindProto(&insert, 3, *graph_)); + string graph_def; + if (graph_->SerializeToString(&graph_def)) { + insert.BindBlobUnsafe(4, graph_def); + } return insert.StepAndReset(); } - Env* env_; - Sqlite* db_; - GraphDef* graph_; - int64 graph_id_; + Status MaybeFlush() { + if (unflushed_bytes_ >= kFlushBytes) { + TF_RETURN_WITH_CONTEXT_IF_ERROR(txn_->Commit(), "flushing ", + unflushed_bytes_, " bytes"); + unflushed_bytes_ = 0; + } + return Status::OK(); + } + + Sqlite* const db_; + SqliteTransaction* const txn_; + uint64 unflushed_bytes_ = 0; + GraphDef* const graph_; + const uint64 now_; + const int64 graph_id_; std::vector<string> name_copies_; std::unordered_map<StringPiece, int64, StringPieceHasher> name_to_node_id_; + + TF_DISALLOW_COPY_AND_ASSIGN(GraphWriter); }; -class RunWriter { +/// \brief Run metadata manager. +/// +/// This class gives us Tag IDs we can pass to SeriesWriter. In order +/// to do that, rows are created in the Ids, Tags, Runs, Experiments, +/// and Users tables. +/// +/// This class is thread safe. +class RunMetadata { public: - RunWriter(Env* env, Sqlite* db, const string& experiment_name, - const string& run_name, const string& user_name) - : env_{env}, - db_{db}, - id_allocator_{env_, db_}, + RunMetadata(IdAllocator* ids, const string& experiment_name, + const string& run_name, const string& user_name) + : ids_{ids}, experiment_name_{experiment_name}, run_name_{run_name}, - user_name_{user_name}, - insert_tensor_{db_->PrepareOrDie(R"sql( - INSERT OR REPLACE INTO Tensors (tag_id, step, computed_time, tensor) - VALUES (?, ?, ?, snap(?)) - )sql")} { - db_->Ref(); + user_name_{user_name} { + DCHECK(ids_ != nullptr); } - ~RunWriter() { - if (run_id_ != kAbsent) { - auto update = db_->PrepareOrDie(R"sql( - UPDATE Runs SET finished_time = ? WHERE run_id = ? 
- )sql"); - update.BindDouble(1, GetWallTime(env_)); - update.BindInt(2, run_id_); - Status s = update.StepAndReset(); - if (!s.ok()) { - LOG(ERROR) << "Failed to set Runs[" << run_id_ - << "].finish_time: " << s.ToString(); - } - } - db_->Unref(); - } + const string& experiment_name() { return experiment_name_; } + const string& run_name() { return run_name_; } + const string& user_name() { return user_name_; } - Status InsertTensor(int64 tag_id, int64 step, double computed_time, - Tensor t) { - insert_tensor_.BindInt(1, tag_id); - insert_tensor_.BindInt(2, step); - insert_tensor_.BindDouble(3, computed_time); - if (t.shape().dims() == 0 && t.dtype() == DT_INT64) { - insert_tensor_.BindInt(4, t.scalar<int64>()()); - } else if (t.shape().dims() == 0 && t.dtype() == DT_DOUBLE) { - insert_tensor_.BindDouble(4, t.scalar<double>()()); - } else { - TF_RETURN_IF_ERROR(BindTensor(&insert_tensor_, 4, t)); - } - return insert_tensor_.StepAndReset(); + int64 run_id() LOCKS_EXCLUDED(mu_) { + mutex_lock lock(mu_); + return run_id_; } - Status InsertGraph(std::unique_ptr<GraphDef> g, double computed_time) { - TF_RETURN_IF_ERROR(InitializeRun(computed_time)); + Status SetGraph(Sqlite* db, uint64 now, double computed_time, + std::unique_ptr<GraphDef> g) SQLITE_TRANSACTIONS_EXCLUDED(*db) + LOCKS_EXCLUDED(mu_) { + int64 run_id; + { + mutex_lock lock(mu_); + TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time)); + run_id = run_id_; + } int64 graph_id; + SqliteTransaction txn(*db); // only to increase performance TF_RETURN_IF_ERROR( - GraphSaver::Save(env_, db_, &id_allocator_, g.get(), &graph_id)); - if (run_id_ != kAbsent) { - auto set = - db_->PrepareOrDie("UPDATE Runs SET graph_id = ? WHERE run_id = ?"); - set.BindInt(1, graph_id); - set.BindInt(2, run_id_); - TF_RETURN_IF_ERROR(set.StepAndReset()); - } - return Status::OK(); + GraphWriter::Save(db, &txn, ids_, g.get(), now, run_id, &graph_id)); + return txn.Commit(); } - Status GetTagId(double computed_time, const string& tag_name, - const SummaryMetadata& metadata, int64* tag_id) { - TF_RETURN_IF_ERROR(InitializeRun(computed_time)); + Status GetTagId(Sqlite* db, uint64 now, double computed_time, + const string& tag_name, int64* tag_id, + const SummaryMetadata& metadata) LOCKS_EXCLUDED(mu_) { + mutex_lock lock(mu_); + TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time)); auto e = tag_ids_.find(tag_name); if (e != tag_ids_.end()) { *tag_id = e->second; return Status::OK(); } - TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(tag_id)); + TF_RETURN_IF_ERROR(ids_->CreateNewId(tag_id)); tag_ids_[tag_name] = *tag_id; - if (!metadata.summary_description().empty()) { - SqliteStatement insert_description = db_->PrepareOrDie(R"sql( - INSERT INTO Descriptions (id, description) VALUES (?, ?) - )sql"); - insert_description.BindInt(1, *tag_id); - insert_description.BindText(2, metadata.summary_description()); - TF_RETURN_IF_ERROR(insert_description.StepAndReset()); - } - SqliteStatement insert = db_->PrepareOrDie(R"sql( + TF_RETURN_IF_ERROR( + SetDescription(db, *tag_id, metadata.summary_description())); + const char* sql = R"sql( INSERT INTO Tags ( run_id, tag_id, @@ -372,30 +458,54 @@ class RunWriter { display_name, plugin_name, plugin_data - ) VALUES (?, ?, ?, ?, ?, ?, ?) 
- )sql"); - if (run_id_ != kAbsent) insert.BindInt(1, run_id_); - insert.BindInt(2, *tag_id); - insert.BindText(3, tag_name); - insert.BindDouble(4, GetWallTime(env_)); - if (!metadata.display_name().empty()) { - insert.BindText(5, metadata.display_name()); - } - if (!metadata.plugin_data().plugin_name().empty()) { - insert.BindText(6, metadata.plugin_data().plugin_name()); - } - if (!metadata.plugin_data().content().empty()) { - insert.BindBlob(7, metadata.plugin_data().content()); - } + ) VALUES ( + :run_id, + :tag_id, + :tag_name, + :inserted_time, + :display_name, + :plugin_name, + :plugin_data + ) + )sql"; + SqliteStatement insert; + TF_RETURN_IF_ERROR(db->Prepare(sql, &insert)); + if (run_id_ != kAbsent) insert.BindInt(":run_id", run_id_); + insert.BindInt(":tag_id", *tag_id); + insert.BindTextUnsafe(":tag_name", tag_name); + insert.BindDouble(":inserted_time", DoubleTime(now)); + insert.BindTextUnsafe(":display_name", metadata.display_name()); + insert.BindTextUnsafe(":plugin_name", metadata.plugin_data().plugin_name()); + insert.BindBlobUnsafe(":plugin_data", metadata.plugin_data().content()); return insert.StepAndReset(); } + Status GetIsWatching(Sqlite* db, bool* is_watching) + SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { + mutex_lock lock(mu_); + if (experiment_id_ == kAbsent) { + *is_watching = true; + return Status::OK(); + } + const char* sql = R"sql( + SELECT is_watching FROM Experiments WHERE experiment_id = ? + )sql"; + SqliteStatement stmt; + TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt)); + stmt.BindInt(1, experiment_id_); + TF_RETURN_IF_ERROR(stmt.StepOnce()); + *is_watching = stmt.ColumnInt(0) != 0; + return Status::OK(); + } + private: - Status InitializeUser() { + Status InitializeUser(Sqlite* db, uint64 now) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (user_id_ != kAbsent || user_name_.empty()) return Status::OK(); - SqliteStatement get = db_->PrepareOrDie(R"sql( + const char* get_sql = R"sql( SELECT user_id FROM Users WHERE user_name = ? - )sql"); + )sql"; + SqliteStatement get; + TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get)); get.BindText(1, user_name_); bool is_done; TF_RETURN_IF_ERROR(get.Step(&is_done)); @@ -403,22 +513,29 @@ class RunWriter { user_id_ = get.ColumnInt(0); return Status::OK(); } - TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&user_id_)); - SqliteStatement insert = db_->PrepareOrDie(R"sql( - INSERT INTO Users (user_id, user_name, inserted_time) VALUES (?, ?, ?) - )sql"); + TF_RETURN_IF_ERROR(ids_->CreateNewId(&user_id_)); + const char* insert_sql = R"sql( + INSERT INTO Users ( + user_id, + user_name, + inserted_time + ) VALUES (?, ?, ?) + )sql"; + SqliteStatement insert; + TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert)); insert.BindInt(1, user_id_); insert.BindText(2, user_name_); - insert.BindDouble(3, GetWallTime(env_)); + insert.BindDouble(3, DoubleTime(now)); TF_RETURN_IF_ERROR(insert.StepAndReset()); return Status::OK(); } - Status InitializeExperiment(double computed_time) { + Status InitializeExperiment(Sqlite* db, uint64 now, double computed_time) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (experiment_name_.empty()) return Status::OK(); if (experiment_id_ == kAbsent) { - TF_RETURN_IF_ERROR(InitializeUser()); - SqliteStatement get = db_->PrepareOrDie(R"sql( + TF_RETURN_IF_ERROR(InitializeUser(db, now)); + const char* get_sql = R"sql( SELECT experiment_id, started_time @@ -427,7 +544,9 @@ class RunWriter { WHERE user_id IS ? AND experiment_name = ? 
- )sql"); + )sql"; + SqliteStatement get; + TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get)); if (user_id_ != kAbsent) get.BindInt(1, user_id_); get.BindText(2, experiment_name_); bool is_done; @@ -436,30 +555,41 @@ class RunWriter { experiment_id_ = get.ColumnInt(0); experiment_started_time_ = get.ColumnInt(1); } else { - TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&experiment_id_)); + TF_RETURN_IF_ERROR(ids_->CreateNewId(&experiment_id_)); experiment_started_time_ = computed_time; - SqliteStatement insert = db_->PrepareOrDie(R"sql( + const char* insert_sql = R"sql( INSERT INTO Experiments ( user_id, experiment_id, experiment_name, inserted_time, - started_time - ) VALUES (?, ?, ?, ?, ?) - )sql"); + started_time, + is_watching + ) VALUES (?, ?, ?, ?, ?, ?) + )sql"; + SqliteStatement insert; + TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert)); if (user_id_ != kAbsent) insert.BindInt(1, user_id_); insert.BindInt(2, experiment_id_); insert.BindText(3, experiment_name_); - insert.BindDouble(4, GetWallTime(env_)); + insert.BindDouble(4, DoubleTime(now)); insert.BindDouble(5, computed_time); + insert.BindInt(6, 0); TF_RETURN_IF_ERROR(insert.StepAndReset()); } } if (computed_time < experiment_started_time_) { experiment_started_time_ = computed_time; - SqliteStatement update = db_->PrepareOrDie(R"sql( - UPDATE Experiments SET started_time = ? WHERE experiment_id = ? - )sql"); + const char* update_sql = R"sql( + UPDATE + Experiments + SET + started_time = ? + WHERE + experiment_id = ? + )sql"; + SqliteStatement update; + TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update)); update.BindDouble(1, computed_time); update.BindInt(2, experiment_id_); TF_RETURN_IF_ERROR(update.StepAndReset()); @@ -467,13 +597,14 @@ class RunWriter { return Status::OK(); } - Status InitializeRun(double computed_time) { + Status InitializeRun(Sqlite* db, uint64 now, double computed_time) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (run_name_.empty()) return Status::OK(); - TF_RETURN_IF_ERROR(InitializeExperiment(computed_time)); + TF_RETURN_IF_ERROR(InitializeExperiment(db, now, computed_time)); if (run_id_ == kAbsent) { - TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&run_id_)); + TF_RETURN_IF_ERROR(ids_->CreateNewId(&run_id_)); run_started_time_ = computed_time; - SqliteStatement insert = db_->PrepareOrDie(R"sql( + const char* insert_sql = R"sql( INSERT OR REPLACE INTO Runs ( experiment_id, run_id, @@ -481,19 +612,28 @@ class RunWriter { inserted_time, started_time ) VALUES (?, ?, ?, ?, ?) - )sql"); + )sql"; + SqliteStatement insert; + TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert)); if (experiment_id_ != kAbsent) insert.BindInt(1, experiment_id_); insert.BindInt(2, run_id_); insert.BindText(3, run_name_); - insert.BindDouble(4, GetWallTime(env_)); + insert.BindDouble(4, DoubleTime(now)); insert.BindDouble(5, computed_time); TF_RETURN_IF_ERROR(insert.StepAndReset()); } if (computed_time < run_started_time_) { run_started_time_ = computed_time; - SqliteStatement update = db_->PrepareOrDie(R"sql( - UPDATE Runs SET started_time = ? WHERE run_id = ? - )sql"); + const char* update_sql = R"sql( + UPDATE + Runs + SET + started_time = ? + WHERE + run_id = ? 
+ )sql"; + SqliteStatement update; + TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update)); update.BindDouble(1, computed_time); update.BindInt(2, run_id_); TF_RETURN_IF_ERROR(update.StepAndReset()); @@ -501,79 +641,400 @@ class RunWriter { return Status::OK(); } - Env* env_; - Sqlite* db_; - IdAllocator id_allocator_; + mutex mu_; + IdAllocator* const ids_; const string experiment_name_; const string run_name_; const string user_name_; - int64 experiment_id_ = kAbsent; - int64 run_id_ = kAbsent; - int64 user_id_ = kAbsent; - std::unordered_map<string, int64> tag_ids_; - double experiment_started_time_ = 0.0; - double run_started_time_ = 0.0; - SqliteStatement insert_tensor_; + int64 experiment_id_ GUARDED_BY(mu_) = kAbsent; + int64 run_id_ GUARDED_BY(mu_) = kAbsent; + int64 user_id_ GUARDED_BY(mu_) = kAbsent; + double experiment_started_time_ GUARDED_BY(mu_) = 0.0; + double run_started_time_ GUARDED_BY(mu_) = 0.0; + std::unordered_map<string, int64> tag_ids_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(RunMetadata); }; +/// \brief Tensor writer for a single series, e.g. Tag. +/// +/// This class can be used to write an infinite stream of Tensors to the +/// database in a fixed block of contiguous disk space. This is +/// accomplished using Algorithm R reservoir sampling. +/// +/// The reservoir consists of a fixed number of rows, which are inserted +/// using ZEROBLOB upon receiving the first sample, which is used to +/// predict how big the other ones are likely to be. This is done +/// transactionally in a way that tries to be mindful of other processes +/// that might be trying to access the same DB. +/// +/// Once the reservoir fills up, rows are replaced at random, and writes +/// gradually become no-ops. This allows long training to go fast +/// without configuration. The exception is when someone is actually +/// looking at TensorBoard. When that happens, the "keep last" behavior +/// is turned on and Append() will always result in a write. +/// +/// If no one is watching training, this class still holds on to the +/// most recent "dangling" Tensor, so if Finish() is called, the most +/// recent training state can be written to disk. +/// +/// The randomly selected sampling points should be consistent across +/// multiple instances. +/// +/// This class is thread safe. 
+class SeriesWriter { + public: + SeriesWriter(int64 series, int slots, RunMetadata* meta) + : series_{series}, + slots_{slots}, + meta_{meta}, + rng_{std::mt19937_64::default_seed} { + DCHECK(series_ > 0); + DCHECK(slots_ > 0); + } + + Status Append(Sqlite* db, int64 step, uint64 now, double computed_time, + Tensor t) SQLITE_TRANSACTIONS_EXCLUDED(*db) + LOCKS_EXCLUDED(mu_) { + mutex_lock lock(mu_); + if (rowids_.empty()) { + Status s = Reserve(db, t); + if (!s.ok()) { + rowids_.clear(); + return s; + } + } + DCHECK(rowids_.size() == slots_); + int64 rowid; + size_t i = count_; + if (i < slots_) { + rowid = last_rowid_ = rowids_[i]; + } else { + i = rng_() % (i + 1); + if (i < slots_) { + rowid = last_rowid_ = rowids_[i]; + } else { + bool keep_last; + TF_RETURN_IF_ERROR(meta_->GetIsWatching(db, &keep_last)); + if (!keep_last) { + ++count_; + dangling_tensor_.reset(new Tensor(std::move(t))); + dangling_step_ = step; + dangling_computed_time_ = computed_time; + return Status::OK(); + } + rowid = last_rowid_; + } + } + Status s = Write(db, rowid, step, computed_time, t); + if (s.ok()) { + ++count_; + dangling_tensor_.reset(); + } + return s; + } + + Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) + LOCKS_EXCLUDED(mu_) { + mutex_lock lock(mu_); + // Short runs: Delete unused pre-allocated Tensors. + if (count_ < rowids_.size()) { + SqliteTransaction txn(*db); + const char* sql = R"sql( + DELETE FROM Tensors WHERE rowid = ? + )sql"; + SqliteStatement deleter; + TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter)); + for (size_t i = count_; i < rowids_.size(); ++i) { + deleter.BindInt(1, rowids_[i]); + TF_RETURN_IF_ERROR(deleter.StepAndReset()); + } + TF_RETURN_IF_ERROR(txn.Commit()); + rowids_.clear(); + } + // Long runs: Make last sample be the very most recent one. + if (dangling_tensor_) { + DCHECK(last_rowid_ != kAbsent); + TF_RETURN_IF_ERROR(Write(db, last_rowid_, dangling_step_, + dangling_computed_time_, *dangling_tensor_)); + dangling_tensor_.reset(); + } + return Status::OK(); + } + + private: + Status Write(Sqlite* db, int64 rowid, int64 step, double computed_time, + const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) { + if (t.dtype() == DT_STRING) { + if (t.dims() == 0) { + return Update(db, step, computed_time, t, t.scalar<string>()(), rowid); + } else { + SqliteTransaction txn(*db); + TF_RETURN_IF_ERROR( + Update(db, step, computed_time, t, StringPiece(), rowid)); + TF_RETURN_IF_ERROR(UpdateNdString(db, t, rowid)); + return txn.Commit(); + } + } else { + return Update(db, step, computed_time, t, t.tensor_data(), rowid); + } + } + + Status Update(Sqlite* db, int64 step, double computed_time, const Tensor& t, + const StringPiece& data, int64 rowid) { + // TODO(jart): How can we ensure reservoir fills on replace? + const char* sql = R"sql( + UPDATE OR REPLACE + Tensors + SET + step = ?, + computed_time = ?, + dtype = ?, + shape = ?, + data = ? + WHERE + rowid = ? + )sql"; + SqliteStatement stmt; + TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt)); + stmt.BindInt(1, step); + stmt.BindDouble(2, computed_time); + stmt.BindInt(3, t.dtype()); + stmt.BindText(4, StringifyShape(t.shape())); + stmt.BindBlobUnsafe(5, data); + stmt.BindInt(6, rowid); + TF_RETURN_IF_ERROR(stmt.StepAndReset()); + return Status::OK(); + } + + Status UpdateNdString(Sqlite* db, const Tensor& t, int64 tensor_rowid) + SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) { + DCHECK_EQ(t.dtype(), DT_STRING); + DCHECK_GT(t.dims(), 0); + const char* deleter_sql = R"sql( + DELETE FROM TensorStrings WHERE tensor_rowid = ? 
+ )sql"; + SqliteStatement deleter; + TF_RETURN_IF_ERROR(db->Prepare(deleter_sql, &deleter)); + deleter.BindInt(1, tensor_rowid); + TF_RETURN_WITH_CONTEXT_IF_ERROR(deleter.StepAndReset(), tensor_rowid); + const char* inserter_sql = R"sql( + INSERT INTO TensorStrings ( + tensor_rowid, + idx, + data + ) VALUES (?, ?, ?) + )sql"; + SqliteStatement inserter; + TF_RETURN_IF_ERROR(db->Prepare(inserter_sql, &inserter)); + auto flat = t.flat<string>(); + for (int64 i = 0; i < flat.size(); ++i) { + inserter.BindInt(1, tensor_rowid); + inserter.BindInt(2, i); + inserter.BindBlobUnsafe(3, flat(i)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(inserter.StepAndReset(), "i=", i); + } + return Status::OK(); + } + + Status Reserve(Sqlite* db, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + SqliteTransaction txn(*db); // only for performance + unflushed_bytes_ = 0; + if (t.dtype() == DT_STRING) { + if (t.dims() == 0) { + TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.scalar<string>()().size())); + } else { + TF_RETURN_IF_ERROR(ReserveTensors(db, &txn, kReserveMinBytes)); + } + } else { + TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.tensor_data().size())); + } + return txn.Commit(); + } + + Status ReserveData(Sqlite* db, SqliteTransaction* txn, size_t size) + SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int64 space = + static_cast<int64>(static_cast<double>(size) * kReserveMultiplier); + if (space < kReserveMinBytes) space = kReserveMinBytes; + return ReserveTensors(db, txn, space); + } + + Status ReserveTensors(Sqlite* db, SqliteTransaction* txn, + int64 reserved_bytes) + SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + const char* sql = R"sql( + INSERT INTO Tensors ( + series, + data + ) VALUES (?, ZEROBLOB(?)) + )sql"; + SqliteStatement insert; + TF_RETURN_IF_ERROR(db->Prepare(sql, &insert)); + // TODO(jart): Maybe preallocate index pages by setting step. This + // is tricky because UPDATE OR REPLACE can have a side + // effect of deleting preallocated rows. + for (int64 i = 0; i < slots_; ++i) { + insert.BindInt(1, series_); + insert.BindInt(2, reserved_bytes); + TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i); + rowids_.push_back(db->last_insert_rowid()); + unflushed_bytes_ += reserved_bytes; + TF_RETURN_IF_ERROR(MaybeFlush(db, txn)); + } + return Status::OK(); + } + + Status MaybeFlush(Sqlite* db, SqliteTransaction* txn) + SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (unflushed_bytes_ >= kFlushBytes) { + TF_RETURN_WITH_CONTEXT_IF_ERROR(txn->Commit(), "flushing ", + unflushed_bytes_, " bytes"); + unflushed_bytes_ = 0; + } + return Status::OK(); + } + + mutex mu_; + const int64 series_; + const int slots_; + RunMetadata* const meta_; + std::mt19937_64 rng_ GUARDED_BY(mu_); + uint64 count_ GUARDED_BY(mu_) = 0; + int64 last_rowid_ GUARDED_BY(mu_) = kAbsent; + std::vector<int64> rowids_ GUARDED_BY(mu_); + uint64 unflushed_bytes_ GUARDED_BY(mu_) = 0; + std::unique_ptr<Tensor> dangling_tensor_ GUARDED_BY(mu_); + int64 dangling_step_ GUARDED_BY(mu_) = 0; + double dangling_computed_time_ GUARDED_BY(mu_) = 0.0; + + TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter); +}; + +/// \brief Tensor writer for a single Run. +/// +/// This class farms out tensors to SeriesWriter instances. It also +/// keeps track of whether or not someone is watching the TensorBoard +/// GUI, so it can avoid writes when possible. +/// +/// This class is thread safe. 
+class RunWriter { + public: + explicit RunWriter(RunMetadata* meta) : meta_{meta} {} + + Status Append(Sqlite* db, int64 tag_id, int64 step, uint64 now, + double computed_time, Tensor t, int slots) + SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { + SeriesWriter* writer = GetSeriesWriter(tag_id, slots); + return writer->Append(db, step, now, computed_time, std::move(t)); + } + + Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) + LOCKS_EXCLUDED(mu_) { + mutex_lock lock(mu_); + if (series_writers_.empty()) return Status::OK(); + for (auto i = series_writers_.begin(); i != series_writers_.end(); ++i) { + if (!i->second) continue; + TF_RETURN_WITH_CONTEXT_IF_ERROR(i->second->Finish(db), + "finish tag_id=", i->first); + i->second.reset(); + } + return Status::OK(); + } + + private: + SeriesWriter* GetSeriesWriter(int64 tag_id, int slots) LOCKS_EXCLUDED(mu_) { + mutex_lock sl(mu_); + auto spot = series_writers_.find(tag_id); + if (spot == series_writers_.end()) { + SeriesWriter* writer = new SeriesWriter(tag_id, slots, meta_); + series_writers_[tag_id].reset(writer); + return writer; + } else { + return spot->second.get(); + } + } + + mutex mu_; + RunMetadata* const meta_; + std::unordered_map<int64, std::unique_ptr<SeriesWriter>> series_writers_ + GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(RunWriter); +}; + +/// \brief SQLite implementation of SummaryWriterInterface. +/// +/// This class is thread safe. class SummaryDbWriter : public SummaryWriterInterface { public: - SummaryDbWriter(Env* env, Sqlite* db, - const string& experiment_name, const string& run_name, - const string& user_name) - : env_{env}, - run_writer_{env, db, experiment_name, run_name, user_name} {} - ~SummaryDbWriter() override {} + SummaryDbWriter(Env* env, Sqlite* db, const string& experiment_name, + const string& run_name, const string& user_name) + : SummaryWriterInterface(), + env_{env}, + db_{db}, + ids_{env_, db_}, + meta_{&ids_, experiment_name, run_name, user_name}, + run_{&meta_} { + DCHECK(env_ != nullptr); + db_->Ref(); + } + + ~SummaryDbWriter() override { + core::ScopedUnref unref(db_); + Status s = run_.Finish(db_); + if (!s.ok()) { + // TODO(jart): Retry on transient errors here. + LOG(ERROR) << s.ToString(); + } + int64 run_id = meta_.run_id(); + if (run_id == kAbsent) return; + const char* sql = R"sql( + UPDATE Runs SET finished_time = ? WHERE run_id = ? 
+ )sql"; + SqliteStatement update; + s = db_->Prepare(sql, &update); + if (s.ok()) { + update.BindDouble(1, DoubleTime(env_->NowMicros())); + update.BindInt(2, run_id); + s = update.StepAndReset(); + } + if (!s.ok()) { + LOG(ERROR) << "Failed to set Runs[" << run_id + << "].finish_time: " << s.ToString(); + } + } Status Flush() override { return Status::OK(); } Status WriteTensor(int64 global_step, Tensor t, const string& tag, const string& serialized_metadata) override { - mutex_lock ml(mu_); + TF_RETURN_IF_ERROR(CheckSupportedType(t)); SummaryMetadata metadata; - if (!serialized_metadata.empty()) { - metadata.ParseFromString(serialized_metadata); + if (!metadata.ParseFromString(serialized_metadata)) { + return errors::InvalidArgument("Bad serialized_metadata"); } - double now = GetWallTime(env_); - int64 tag_id; - TF_RETURN_IF_ERROR(run_writer_.GetTagId(now, tag, metadata, &tag_id)); - return run_writer_.InsertTensor(tag_id, global_step, now, t); + return Write(global_step, t, tag, metadata); } Status WriteScalar(int64 global_step, Tensor t, const string& tag) override { - Tensor t2; - TF_RETURN_IF_ERROR(CoerceScalar(t, &t2)); - // TODO(jart): Generate scalars plugin metadata on this value. - return WriteTensor(global_step, std::move(t2), tag, ""); + TF_RETURN_IF_ERROR(CheckSupportedType(t)); + SummaryMetadata metadata; + PatchPluginName(&metadata, kScalarPluginName); + return Write(global_step, AsScalar(t), tag, metadata); } Status WriteGraph(int64 global_step, std::unique_ptr<GraphDef> g) override { - mutex_lock ml(mu_); - return run_writer_.InsertGraph(std::move(g), GetWallTime(env_)); + uint64 now = env_->NowMicros(); + return meta_.SetGraph(db_, now, DoubleTime(now), std::move(g)); } Status WriteEvent(std::unique_ptr<Event> e) override { - switch (e->what_case()) { - case Event::WhatCase::kSummary: { - mutex_lock ml(mu_); - Status s; - for (const auto& value : e->summary().value()) { - s.Update(WriteSummary(e.get(), value)); - } - return s; - } - case Event::WhatCase::kGraphDef: { - mutex_lock ml(mu_); - std::unique_ptr<GraphDef> graph{new GraphDef}; - if (!ParseProtoUnlimited(graph.get(), e->graph_def())) { - return errors::DataLoss("parse event.graph_def failed"); - } - return run_writer_.InsertGraph(std::move(graph), e->wall_time()); - } - default: - // TODO(@jart): Handle other stuff. 
- return Status::OK(); - } + return MigrateEvent(std::move(e)); } Status WriteHistogram(int64 global_step, Tensor t, @@ -600,26 +1061,165 @@ class SummaryDbWriter : public SummaryWriterInterface { string DebugString() override { return "SummaryDbWriter"; } private: - Status WriteSummary(const Event* e, const Summary::Value& summary) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - switch (summary.value_case()) { - case Summary::Value::ValueCase::kSimpleValue: { - int64 tag_id; - TF_RETURN_IF_ERROR(run_writer_.GetTagId(e->wall_time(), summary.tag(), - summary.metadata(), &tag_id)); - Tensor t{DT_DOUBLE, {}}; - t.scalar<double>()() = summary.simple_value(); - return run_writer_.InsertTensor(tag_id, e->step(), e->wall_time(), t); + Status Write(int64 step, const Tensor& t, const string& tag, + const SummaryMetadata& metadata) { + uint64 now = env_->NowMicros(); + double computed_time = DoubleTime(now); + int64 tag_id; + TF_RETURN_IF_ERROR( + meta_.GetTagId(db_, now, computed_time, tag, &tag_id, metadata)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + run_.Append(db_, tag_id, step, now, computed_time, t, + GetSlots(t, metadata)), + meta_.user_name(), "/", meta_.experiment_name(), "/", meta_.run_name(), + "/", tag, "@", step); + return Status::OK(); + } + + Status MigrateEvent(std::unique_ptr<Event> e) { + switch (e->what_case()) { + case Event::WhatCase::kSummary: { + uint64 now = env_->NowMicros(); + auto summaries = e->mutable_summary(); + for (int i = 0; i < summaries->value_size(); ++i) { + Summary::Value* value = summaries->mutable_value(i); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + MigrateSummary(e.get(), value, now), meta_.user_name(), "/", + meta_.experiment_name(), "/", meta_.run_name(), "/", value->tag(), + "@", e->step()); + } + break; } + case Event::WhatCase::kGraphDef: + TF_RETURN_WITH_CONTEXT_IF_ERROR( + MigrateGraph(e.get(), e->graph_def()), meta_.user_name(), "/", + meta_.experiment_name(), "/", meta_.run_name(), "/__graph__@", + e->step()); + break; default: - // TODO(@jart): Handle the rest. - return Status::OK(); + // TODO(@jart): Handle other stuff. 
+ break; } + return Status::OK(); } - mutex mu_; - Env* env_; - RunWriter run_writer_ GUARDED_BY(mu_); + Status MigrateGraph(const Event* e, const string& graph_def) { + uint64 now = env_->NowMicros(); + std::unique_ptr<GraphDef> graph{new GraphDef}; + if (!ParseProtoUnlimited(graph.get(), graph_def)) { + return errors::InvalidArgument("bad proto"); + } + return meta_.SetGraph(db_, now, e->wall_time(), std::move(graph)); + } + + Status MigrateSummary(const Event* e, Summary::Value* s, uint64 now) { + switch (s->value_case()) { + case Summary::Value::ValueCase::kTensor: + TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateTensor(e, s, now), "tensor"); + break; + case Summary::Value::ValueCase::kSimpleValue: + TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateScalar(e, s, now), "scalar"); + break; + case Summary::Value::ValueCase::kHisto: + TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateHistogram(e, s, now), "histo"); + break; + case Summary::Value::ValueCase::kImage: + TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateImage(e, s, now), "image"); + break; + case Summary::Value::ValueCase::kAudio: + TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateAudio(e, s, now), "audio"); + break; + default: + break; + } + return Status::OK(); + } + + Status MigrateTensor(const Event* e, Summary::Value* s, uint64 now) { + Tensor t; + if (!t.FromProto(s->tensor())) return errors::InvalidArgument("bad proto"); + TF_RETURN_IF_ERROR(CheckSupportedType(t)); + int64 tag_id; + TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), + &tag_id, s->metadata())); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t, + GetSlots(t, s->metadata())); + } + + // TODO(jart): Refactor Summary -> Tensor logic into separate file. + + Status MigrateScalar(const Event* e, Summary::Value* s, uint64 now) { + // See tensorboard/plugins/scalar/summary.py and data_compat.py + Tensor t{DT_FLOAT, {}}; + t.scalar<float>()() = s->simple_value(); + int64 tag_id; + PatchPluginName(s->mutable_metadata(), kScalarPluginName); + TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), + &tag_id, s->metadata())); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), + std::move(t), kScalarSlots); + } + + Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) { + const HistogramProto& histo = s->histo(); + int k = histo.bucket_size(); + if (k != histo.bucket_limit_size()) { + return errors::InvalidArgument("size mismatch"); + } + // See tensorboard/plugins/histogram/summary.py and data_compat.py + Tensor t{DT_DOUBLE, {k, 3}}; + auto data = t.flat<double>(); + for (int i = 0; i < k; ++i) { + double left_edge = ((i - 1 >= 0) ? histo.bucket_limit(i - 1) + : std::numeric_limits<double>::min()); + double right_edge = ((i + 1 < k) ? 
histo.bucket_limit(i + 1)
+                                     : std::numeric_limits<double>::max());
+      data(i * 3 + 0) = left_edge;
+      data(i * 3 + 1) = right_edge;
+      data(i * 3 + 2) = histo.bucket(i);
+    }
+    int64 tag_id;
+    PatchPluginName(s->mutable_metadata(), kHistogramPluginName);
+    TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
+                                      &tag_id, s->metadata()));
+    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
+                       std::move(t), kHistogramSlots);
+  }
+
+  Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) {
+    // See tensorboard/plugins/image/summary.py and data_compat.py
+    Tensor t{DT_STRING, {3}};
+    auto img = s->mutable_image();
+    t.flat<string>()(0) = strings::StrCat(img->width());
+    t.flat<string>()(1) = strings::StrCat(img->height());
+    t.flat<string>()(2) = std::move(*img->mutable_encoded_image_string());
+    int64 tag_id;
+    PatchPluginName(s->mutable_metadata(), kImagePluginName);
+    TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
+                                      &tag_id, s->metadata()));
+    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
+                       std::move(t), kImageSlots);
+  }
+
+  Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) {
+    // See tensorboard/plugins/audio/summary.py and data_compat.py
+    Tensor t{DT_STRING, {1, 2}};
+    auto wav = s->mutable_audio();
+    t.flat<string>()(0) = std::move(*wav->mutable_encoded_audio_string());
+    t.flat<string>()(1) = "";
+    int64 tag_id;
+    PatchPluginName(s->mutable_metadata(), kAudioPluginName);
+    TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
+                                      &tag_id, s->metadata()));
+    return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
+                       std::move(t), kAudioSlots);
+  }
+
+  Env* const env_;
+  Sqlite* const db_;
+  IdAllocator ids_;
+  RunMetadata meta_;
+  RunWriter run_;
 };
 
 }  // namespace
 
@@ -627,8 +1227,6 @@ class SummaryDbWriter : public SummaryWriterInterface {
 Status CreateSummaryDbWriter(Sqlite* db, const string& experiment_name,
                              const string& run_name, const string& user_name,
                              Env* env, SummaryWriterInterface** result) {
-  *result = nullptr;
-  TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db));
   *result = new SummaryDbWriter(env, db, experiment_name, run_name, user_name);
   return Status::OK();
 }
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.h b/tensorflow/contrib/tensorboard/db/summary_db_writer.h
index 5a3de195de..746da1533b 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer.h
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.h
@@ -19,16 +19,15 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/db/sqlite.h"
 #include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 
 /// \brief Creates SQLite SummaryWriterInterface.
 ///
 /// This can be used to write tensors from the execution graph directly
-/// to a database. The schema will be created automatically, but only
-/// if necessary. Entries in the Users, Experiments, and Runs tables
-/// will be created automatically if they don't already exist.
+/// to a database. The schema must be created beforehand. Entries in
+/// Users, Experiments, and Runs tables will be created automatically
+/// if they don't already exist.
 ///
 /// Please note that the type signature of this function may change in
 /// the future if support for other DBs is added to core.
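The header comment above documents a behavioral change: CreateSummaryDbWriter no longer calls SetupTensorboardSqliteDb itself, so callers must create the schema before constructing the writer. A minimal sketch of the new call order, modeled on the updated test's SetUp() below, could look like this; the function name WriteOneScalarExample, the experiment/run/user strings, the tag, and the scalar value are illustrative, not part of this change.

    // Sketch only. Assumes the reference-counted ownership shown in this
    // diff: SummaryDbWriter Refs the db in its constructor and Unrefs it
    // when destroyed.
    #include "tensorflow/contrib/tensorboard/db/schema.h"
    #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/kernels/summary_interface.h"
    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/lib/db/sqlite.h"
    #include "tensorflow/core/platform/env.h"

    namespace tensorflow {

    Status WriteOneScalarExample() {
      Sqlite* db = nullptr;
      TF_RETURN_IF_ERROR(Sqlite::Open(":memory:", SQLITE_OPEN_READWRITE, &db));
      // New contract: the schema must exist before the writer is created.
      TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db));
      SummaryWriterInterface* writer = nullptr;
      TF_RETURN_IF_ERROR(CreateSummaryDbWriter(db, "my-experiment", "my-run",
                                               "me", Env::Default(), &writer));
      Tensor t(DT_INT64, TensorShape({}));
      t.scalar<int64>()() = 42;
      TF_RETURN_IF_ERROR(writer->WriteScalar(/*global_step=*/1, t, "loss"));
      TF_RETURN_IF_ERROR(writer->Flush());
      writer->Unref();  // the writer keeps its own reference on db
      db->Unref();      // drop the reference obtained from Sqlite::Open
      return Status::OK();
    }

    }  // namespace tensorflow

The updated test below follows the same ordering by calling SetupTensorboardSqliteDb immediately after Sqlite::Open in SetUp().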
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc index 68444c35be..29b8063218 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" +#include "tensorflow/contrib/tensorboard/db/schema.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/summary.pb.h" @@ -27,8 +29,6 @@ limitations under the License. namespace tensorflow { namespace { -const float kTolerance = 1e-5; - Tensor MakeScalarInt64(int64 x) { Tensor t(DT_INT64, TensorShape({})); t.scalar<int64>()() = x; @@ -50,6 +50,7 @@ class SummaryDbWriterTest : public ::testing::Test { protected: void SetUp() override { TF_ASSERT_OK(Sqlite::Open(":memory:", SQLITE_OPEN_READWRITE, &db_)); + TF_ASSERT_OK(SetupTensorboardSqliteDb(db_)); } void TearDown() override { @@ -138,7 +139,7 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) { ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Experiments")); ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs")); ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); - ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); int64 user_id = QueryInt("SELECT user_id FROM Users"); int64 experiment_id = QueryInt("SELECT experiment_id FROM Experiments"); @@ -170,17 +171,13 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) { EXPECT_EQ("plugin_data", QueryString("SELECT plugin_data FROM Tags")); EXPECT_EQ("description", QueryString("SELECT description FROM Descriptions")); - EXPECT_EQ(tag_id, QueryInt("SELECT tag_id FROM Tensors WHERE step = 1")); + EXPECT_EQ(tag_id, QueryInt("SELECT series FROM Tensors WHERE step = 1")); EXPECT_EQ(0.023, QueryDouble("SELECT computed_time FROM Tensors WHERE step = 1")); - EXPECT_FALSE( - QueryString("SELECT tensor FROM Tensors WHERE step = 1").empty()); - EXPECT_EQ(tag_id, QueryInt("SELECT tag_id FROM Tensors WHERE step = 2")); + EXPECT_EQ(tag_id, QueryInt("SELECT series FROM Tensors WHERE step = 2")); EXPECT_EQ(0.046, QueryDouble("SELECT computed_time FROM Tensors WHERE step = 2")); - EXPECT_FALSE( - QueryString("SELECT tensor FROM Tensors WHERE step = 2").empty()); } TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) { @@ -191,7 +188,7 @@ TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) { ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments")); ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs")); ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); - ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); } TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { @@ -208,33 +205,24 @@ TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); TF_ASSERT_OK(writer_->Flush()); ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags")); - ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + ASSERT_EQ(20000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'π'"); int64 tag2_id = QueryInt("SELECT tag_id FROM 
Tags WHERE tag_name = 'φ'"); EXPECT_GT(tag1_id, 0LL); EXPECT_GT(tag2_id, 0LL); EXPECT_EQ(123.456, QueryDouble(strings::StrCat( - "SELECT computed_time FROM Tensors WHERE tag_id = ", + "SELECT computed_time FROM Tensors WHERE series = ", tag1_id, " AND step = 7"))); EXPECT_EQ(123.456, QueryDouble(strings::StrCat( - "SELECT computed_time FROM Tensors WHERE tag_id = ", + "SELECT computed_time FROM Tensors WHERE series = ", tag2_id, " AND step = 7"))); - EXPECT_NEAR(3.14, - QueryDouble(strings::StrCat( - "SELECT tensor FROM Tensors WHERE tag_id = ", tag1_id, - " AND step = 7")), - kTolerance); // Summary::simple_value is float - EXPECT_NEAR(1.61, - QueryDouble(strings::StrCat( - "SELECT tensor FROM Tensors WHERE tag_id = ", tag2_id, - " AND step = 7")), - kTolerance); } TEST_F(SummaryDbWriterTest, WriteGraph) { TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "R", "", &env_, &writer_)); env_.AdvanceByMillis(23); GraphDef graph; + graph.mutable_library()->add_gradient()->set_function_name("funk"); NodeDef* node = graph.add_node(); node->set_name("x"); node->set_op("Placeholder"); @@ -260,11 +248,17 @@ TEST_F(SummaryDbWriterTest, WriteGraph) { ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Nodes")); ASSERT_EQ(3LL, QueryInt("SELECT COUNT(*) FROM NodeInputs")); + ASSERT_EQ(QueryInt("SELECT run_id FROM Runs"), + QueryInt("SELECT run_id FROM Graphs")); + int64 graph_id = QueryInt("SELECT graph_id FROM Graphs"); EXPECT_GT(graph_id, 0LL); - EXPECT_EQ(graph_id, QueryInt("SELECT graph_id FROM Runs")); EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Graphs")); - EXPECT_FALSE(QueryString("SELECT graph_def FROM Graphs").empty()); + + GraphDef graph2; + graph2.ParseFromString(QueryString("SELECT graph_def FROM Graphs")); + EXPECT_EQ(0, graph2.node_size()); + EXPECT_EQ("funk", graph2.library().gradient(0).function_name()); EXPECT_EQ("x", QueryString("SELECT node_name FROM Nodes WHERE node_id = 0")); EXPECT_EQ("y", QueryString("SELECT node_name FROM Nodes WHERE node_id = 1")); @@ -307,33 +301,6 @@ TEST_F(SummaryDbWriterTest, WriteGraph) { EXPECT_EQ(1LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 2")); } -TEST_F(SummaryDbWriterTest, WriteScalarInt32_CoercesToInt64) { - TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_)); - Tensor t(DT_INT32, {}); - t.scalar<int32>()() = -17; - TF_ASSERT_OK(writer_->WriteScalar(1, t, "t")); - TF_ASSERT_OK(writer_->Flush()); - ASSERT_EQ(-17LL, QueryInt("SELECT tensor FROM Tensors")); -} - -TEST_F(SummaryDbWriterTest, WriteScalarInt8_CoercesToInt64) { - TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_)); - Tensor t(DT_INT8, {}); - t.scalar<int8>()() = static_cast<int8>(-17); - TF_ASSERT_OK(writer_->WriteScalar(1, t, "t")); - TF_ASSERT_OK(writer_->Flush()); - ASSERT_EQ(-17LL, QueryInt("SELECT tensor FROM Tensors")); -} - -TEST_F(SummaryDbWriterTest, WriteScalarUint8_CoercesToInt64) { - TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_)); - Tensor t(DT_UINT8, {}); - t.scalar<uint8>()() = static_cast<uint8>(254); - TF_ASSERT_OK(writer_->WriteScalar(1, t, "t")); - TF_ASSERT_OK(writer_->Flush()); - ASSERT_EQ(254LL, QueryInt("SELECT tensor FROM Tensors")); -} - TEST_F(SummaryDbWriterTest, UsesIdsTable) { SummaryMetadata metadata; TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_, |