aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorboard
diff options
context:
space:
mode:
authorGravatar Justine Tunney <jart@google.com>2017-12-07 16:16:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-07 16:21:34 -0800
commit6c4af6202c3984d7eabc8044c43579315c4b07a2 (patch)
tree3fcc6332f12604deecd91fd84b073897a1d04852 /tensorflow/contrib/tensorboard
parent6e04085f90c5c0c2a49723cc682b16327c994957 (diff)
Make SQLite random IDs unique across tables
We now have an invariant that no two IDs are the same across tables. We also assume that only one instance will ever write tensors to a given run. This approach to IDs also allows us to be more carefree about garbage data, so we don't need the transactions anymore. It also brings latency down from 16ms to 8ms. Name Cold µs Average µs Flushing µs Size B ?i.i 10,020 264 0 0 Scalar 1.0 FS 7,996 711 4,808 11,348 Scalar 1.0 TB FS 14,875 891 5,487 17,023 Scalar 2.0 FS 13,123 891 4,499 11,348 Scalar 2.0 DB 72,497 8,472 8,875 118,784 Tensor 1.0 FS 4 16,128 856 4,785 14,215 Tensor 2.0 FS 4 23,765 1,032 4,508 24,455 Tensor 2.0 DB 4 91,735 8,407 8,175 118,784 Tensor 1.0 FS 128 18,592 831 4,950 47,111 Tensor 2.0 FS 128 18,174 1,033 4,498 57,351 Tensor 2.0 DB 128 98,045 17,799 8,710 118,784 Tensor 1.0 FS 8192 19,225 1,164 5,217 2,119,816 Tensor 2.0 FS 8192 16,979 921 4,360 2,130,056 Tensor 2.0 DB 8192 108,704 8,470 8,543 126,976 PiperOrigin-RevId: 178312341
Diffstat (limited to 'tensorflow/contrib/tensorboard')
-rw-r--r--tensorflow/contrib/tensorboard/db/schema.cc93
-rw-r--r--tensorflow/contrib/tensorboard/db/summary_db_writer.cc548
-rw-r--r--tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc61
3 files changed, 448 insertions, 254 deletions
diff --git a/tensorflow/contrib/tensorboard/db/schema.cc b/tensorflow/contrib/tensorboard/db/schema.cc
index d63b2c6cc2..fd024d692c 100644
--- a/tensorflow/contrib/tensorboard/db/schema.cc
+++ b/tensorflow/contrib/tensorboard/db/schema.cc
@@ -21,6 +21,48 @@ class SqliteSchema {
public:
explicit SqliteSchema(std::shared_ptr<Sqlite> db) : db_(std::move(db)) {}
+ /// \brief Creates Ids table.
+ ///
+ /// This table must be used to randomly allocate Permanent IDs for
+ /// all top-level tables, in order to maintain an invariant where
+ /// foo_id != bar_id for all IDs of any two tables.
+ ///
+ /// A row should only be deleted from this table if it can be
+ /// guaranteed that it exists absolutely nowhere else in the entire
+ /// system.
+ ///
+ /// Fields:
+ /// id: An ID that was allocated globally. This must be in the
+ /// range [1,2**47). 0 is assigned the same meaning as NULL and
+ /// shouldn't be stored; 2**63-1 is reserved for statically
+ /// allocating space in a page to UPDATE later; and all other
+ /// int64 values are reserved for future use.
+ Status CreateIdsTable() {
+ return Run(R"sql(
+ CREATE TABLE IF NOT EXISTS Ids (
+ id INTEGER PRIMARY KEY
+ )
+ )sql");
+ }
+
+ /// \brief Creates Descriptions table.
+ ///
+ /// This table allows TensorBoard to associate Markdown text with any
+ /// object in the database that has a Permanent ID.
+ ///
+ /// Fields:
+ /// id: The Permanent ID of the associated object. This is also the
+ /// SQLite rowid.
+ /// description: Arbitrary Markdown text.
+ Status CreateDescriptionsTable() {
+ return Run(R"sql(
+ CREATE TABLE IF NOT EXISTS Descriptions (
+ id INTEGER PRIMARY KEY,
+ description TEXT
+ )
+ )sql");
+ }
+
/// \brief Creates Tensors table.
///
/// Fields:
@@ -83,15 +125,15 @@ class SqliteSchema {
///
/// Fields:
/// rowid: Ephemeral b-tree ID dictating locality.
- /// tag_id: Permanent >0 unique ID.
+ /// tag_id: The Permanent ID of the Tag.
/// run_id: Optional ID of associated Run.
/// tag_name: The tag field in summary.proto, unique across Run.
/// inserted_time: Float UNIX timestamp with µs precision. This is
/// always the wall time of when the row was inserted into the
/// DB. It may be used as a hint for an archival job.
- /// metadata: Optional BLOB of SummaryMetadata proto.
/// display_name: Optional for GUI and defaults to tag_name.
- /// summary_description: Optional markdown information.
+ /// plugin_name: Arbitrary TensorBoard plugin name for dispatch.
+ /// plugin_data: Arbitrary data that plugin wants.
Status CreateTagsTable() {
return Run(R"sql(
CREATE TABLE IF NOT EXISTS Tags (
@@ -100,28 +142,31 @@ class SqliteSchema {
tag_id INTEGER NOT NULL,
tag_name TEXT,
inserted_time DOUBLE,
- metadata BLOB,
display_name TEXT,
- description TEXT
+ plugin_name TEXT,
+ plugin_data BLOB
)
)sql");
}
/// \brief Creates Runs table.
///
- /// This table stores information about runs. Each row usually
+ /// This table stores information about Runs. Each row usually
/// represents a single attempt at training or testing a TensorFlow
/// model, with a given set of hyper-parameters, whose summaries are
/// written out to a single event logs directory with a monotonic step
/// counter.
///
- /// When a run is deleted from this table, TensorBoard should treat all
- /// information associated with it as deleted, even if those rows in
- /// different tables still exist.
- ///
/// Fields:
/// rowid: Ephemeral b-tree ID dictating locality.
- /// run_id: Permanent >0 unique ID.
+ /// run_id: The Permanent ID of the Run. This has a 1:1 mapping
+ /// with a SummaryWriter instance. If two writers spawn for a
+ /// given (user_name, run_name, run_name) then each should
+ /// allocate its own run_id and whichever writer puts it in the
+ /// database last wins. The Tags / Tensors associated with the
+ /// previous invocations will then enter limbo, where they may be
+ /// accessible for certain operations, but should be garbage
+ /// collected eventually.
/// experiment_id: Optional ID of associated Experiment.
/// run_name: User-supplied string, unique across Experiment.
/// inserted_time: Float UNIX timestamp with µs precision. This is
@@ -134,7 +179,10 @@ class SqliteSchema {
/// started, from the perspective of whichever machine talks to
/// the database. This field will be mutated if the run is
/// restarted.
- /// description: Optional markdown information.
+ /// finished_time: Float UNIX timestamp with µs precision of when
+ /// SummaryWriter resource that created this run was destroyed.
+ /// Once this value becomes non-NULL a Run and its Tags and
+ /// Tensors should be regarded as immutable.
/// graph_id: ID of associated Graphs row.
Status CreateRunsTable() {
return Run(R"sql(
@@ -145,7 +193,7 @@ class SqliteSchema {
run_name TEXT,
inserted_time REAL,
started_time REAL,
- description TEXT,
+ finished_time REAL,
graph_id INTEGER
)
)sql");
@@ -159,15 +207,15 @@ class SqliteSchema {
/// Fields:
/// rowid: Ephemeral b-tree ID dictating locality.
/// user_id: Optional ID of associated User.
- /// experiment_id: Permanent >0 unique ID.
+ /// experiment_id: The Permanent ID of the Experiment.
/// experiment_name: User-supplied string, unique across User.
/// inserted_time: Float UNIX timestamp with µs precision. This is
/// always the time the row was inserted into the database. It
/// does not change.
/// started_time: Float UNIX timestamp with µs precision. This is
/// the MIN(experiment.started_time, run.started_time) of each
- /// Run added to the database.
- /// description: Optional markdown information.
+ /// Run added to the database, including Runs which have since
+ /// been overwritten.
Status CreateExperimentsTable() {
return Run(R"sql(
CREATE TABLE IF NOT EXISTS Experiments (
@@ -176,8 +224,7 @@ class SqliteSchema {
experiment_id INTEGER NOT NULL,
experiment_name TEXT,
inserted_time REAL,
- started_time REAL,
- description TEXT
+ started_time REAL
)
)sql");
}
@@ -186,7 +233,7 @@ class SqliteSchema {
///
/// Fields:
/// rowid: Ephemeral b-tree ID dictating locality.
- /// user_id: Permanent >0 unique ID.
+ /// user_id: The Permanent ID of the User.
/// user_name: Unique user name.
/// email: Optional unique email address.
/// inserted_time: Float UNIX timestamp with µs precision. This is
@@ -208,7 +255,7 @@ class SqliteSchema {
///
/// Fields:
/// rowid: Ephemeral b-tree ID dictating locality.
- /// graph_id: Permanent >0 unique ID.
+ /// graph_id: The Permanent ID of the Graph.
/// inserted_time: Float UNIX timestamp with µs precision. This is
/// always the wall time of when the row was inserted into the
/// DB. It may be used as a hint for an archival job.
@@ -229,7 +276,7 @@ class SqliteSchema {
///
/// Fields:
/// rowid: Ephemeral b-tree ID dictating locality.
- /// graph_id: Permanent >0 unique ID.
+ /// graph_id: The Permanent ID of the associated Graph.
/// node_id: ID for this node. This is more like a 0-index within
/// the Graph. Please note indexes are allowed to be removed.
/// node_name: Unique name for this Node within Graph. This is
@@ -258,7 +305,7 @@ class SqliteSchema {
///
/// Fields:
/// rowid: Ephemeral b-tree ID dictating locality.
- /// graph_id: Permanent >0 unique ID.
+ /// graph_id: The Permanent ID of the associated Graph.
/// node_id: Index of Node in question. This can be considered the
/// 'to' vertex.
/// idx: Used for ordering inputs on a given Node.
@@ -420,6 +467,8 @@ class SqliteSchema {
Status SetupTensorboardSqliteDb(std::shared_ptr<Sqlite> db) {
SqliteSchema s(std::move(db));
+ TF_RETURN_IF_ERROR(s.CreateIdsTable());
+ TF_RETURN_IF_ERROR(s.CreateDescriptionsTable());
TF_RETURN_IF_ERROR(s.CreateTensorsTable());
TF_RETURN_IF_ERROR(s.CreateTensorChunksTable());
TF_RETURN_IF_ERROR(s.CreateTagsTable());
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
index 37a32acb1e..04b9c8e457 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
@@ -29,22 +29,25 @@ limitations under the License.
namespace tensorflow {
namespace {
+// https://www.sqlite.org/fileformat.html#record_format
+const uint64 kIdTiers[] = {
+ 0x7fffffULL, // 23-bit (3 bytes on disk)
+ 0x7fffffffULL, // 31-bit (4 bytes on disk)
+ 0x7fffffffffffULL, // 47-bit (5 bytes on disk)
+ // Remaining bits reserved for future use.
+};
+const int kMaxIdTier = sizeof(kIdTiers) / sizeof(uint64);
+const int kIdCollisionDelayMicros = 10;
+const int kMaxIdCollisions = 21; // sum(2**i*10µs for i in range(21))~=21s
+const int64 kAbsent = 0LL;
+const int64 kReserved = 0x7fffffffffffffffLL;
+
double GetWallTime(Env* env) {
// TODO(@jart): Follow precise definitions for time laid out in schema.
// TODO(@jart): Use monotonic clock from gRPC codebase.
return static_cast<double>(env->NowMicros()) / 1.0e6;
}
-int64 MakeRandomId() {
- // TODO(@jart): Try generating ID in 2^24 space, falling back to 2^63
- // https://sqlite.org/src4/doc/trunk/www/varint.wiki
- int64 id = static_cast<int64>(random::New64() & ((1ULL << 63) - 1));
- if (id == 0) {
- ++id;
- }
- return id;
-}
-
Status Serialize(const protobuf::MessageLite& proto, string* output) {
output->clear();
if (!proto.SerializeToString(output)) {
@@ -130,54 +133,69 @@ Status CoerceScalar(const Tensor& t, Tensor* out) {
return Status::OK();
}
-class Transactor {
+/// \brief Generates unique IDs randomly in the [1,2**63-2] range.
+///
+/// This class starts off generating IDs in the [1,2**23-1] range,
+/// because it's human friendly and occupies 4 bytes max on disk with
+/// SQLite's zigzag varint encoding. Then, each time a collision
+/// happens, the random space is increased by 8 bits.
+///
+/// This class uses exponential back-off so writes will slow down as
+/// the ID space becomes exhausted.
+class IdAllocator {
public:
- explicit Transactor(std::shared_ptr<Sqlite> db)
- : db_(std::move(db)),
- begin_(db_->Prepare("BEGIN TRANSACTION")),
- commit_(db_->Prepare("COMMIT TRANSACTION")),
- rollback_(db_->Prepare("ROLLBACK TRANSACTION")) {}
-
- template <typename T, typename... Args>
- Status Transact(T callback, Args&&... args) {
- TF_RETURN_IF_ERROR(begin_.StepAndReset());
- Status s = callback(std::forward<Args>(args)...);
- if (s.ok()) {
- TF_RETURN_IF_ERROR(commit_.StepAndReset());
- } else {
- TF_RETURN_WITH_CONTEXT_IF_ERROR(rollback_.StepAndReset(), s.ToString());
+ IdAllocator(Env* env, Sqlite* db)
+ : env_{env}, inserter_{db->Prepare("INSERT INTO Ids (id) VALUES (?)")} {}
+
+ Status CreateNewId(int64* id) {
+ Status s;
+ for (int i = 0; i < kMaxIdCollisions; ++i) {
+ int64 tid = MakeRandomId();
+ inserter_.BindInt(1, tid);
+ s = inserter_.StepAndReset();
+ if (s.ok()) {
+ *id = tid;
+ break;
+ }
+ // SQLITE_CONSTRAINT maps to INVALID_ARGUMENT in sqlite.cc
+ if (s.code() != error::INVALID_ARGUMENT) break;
+ if (tier_ < kMaxIdTier) {
+ LOG(INFO) << "IdAllocator collision at tier " << tier_ << " (of "
+ << kMaxIdTier << ") so auto-adjusting to a higher tier";
+ ++tier_;
+ } else {
+ LOG(WARNING) << "IdAllocator (attempt #" << i << ") "
+ << "resulted in a collision at the highest tier; this "
+ "is problematic if it happens often; you can try "
+ "pruning the Ids table; you can also file a bug "
+ "asking for the ID space to be increased; otherwise "
+ "writes will gradually slow down over time until they "
+ "become impossible";
+ }
+ env_->SleepForMicroseconds((1 << i) * kIdCollisionDelayMicros);
}
return s;
}
private:
- std::shared_ptr<Sqlite> db_;
- SqliteStatement begin_;
- SqliteStatement commit_;
- SqliteStatement rollback_;
+ int64 MakeRandomId() {
+ int64 id = static_cast<int64>(random::New64() & kIdTiers[tier_]);
+ if (id == kAbsent) ++id;
+ if (id == kReserved) --id;
+ return id;
+ }
+
+ Env* env_;
+ SqliteStatement inserter_;
+ int tier_ = 0;
};
class GraphSaver {
public:
- static Status SaveToRun(Env* env, Sqlite* db, GraphDef* graph, int64 run_id) {
- auto get = db->Prepare("SELECT graph_id FROM Runs WHERE run_id = ?");
- get.BindInt(1, run_id);
- bool is_done;
- TF_RETURN_IF_ERROR(get.Step(&is_done));
- int64 graph_id = is_done ? 0 : get.ColumnInt(0);
- if (graph_id == 0) {
- graph_id = MakeRandomId();
- // TODO(@jart): Check for ID collision.
- auto set = db->Prepare("UPDATE Runs SET graph_id = ? WHERE run_id = ?");
- set.BindInt(1, graph_id);
- set.BindInt(2, run_id);
- TF_RETURN_IF_ERROR(set.StepAndReset());
- }
- return Save(env, db, graph, graph_id);
- }
-
- static Status Save(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id) {
- GraphSaver saver{env, db, graph, graph_id};
+ static Status Save(Env* env, Sqlite* db, IdAllocator* id_allocator,
+ GraphDef* graph, int64* graph_id) {
+ TF_RETURN_IF_ERROR(id_allocator->CreateNewId(graph_id));
+ GraphSaver saver{env, db, graph, *graph_id};
saver.MapNameToNodeId();
TF_RETURN_IF_ERROR(saver.SaveNodeInputs());
TF_RETURN_IF_ERROR(saver.SaveNodes());
@@ -202,9 +220,6 @@ class GraphSaver {
}
Status SaveNodeInputs() {
- auto purge = db_->Prepare("DELETE FROM NodeInputs WHERE graph_id = ?");
- purge.BindInt(1, graph_id_);
- TF_RETURN_IF_ERROR(purge.StepAndReset());
auto insert = db_->Prepare(R"sql(
INSERT INTO NodeInputs (graph_id, node_id, idx, input_node_id, is_control)
VALUES (?, ?, ?, ?, ?)
@@ -233,9 +248,6 @@ class GraphSaver {
}
Status SaveNodes() {
- auto purge = db_->Prepare("DELETE FROM Nodes WHERE graph_id = ?");
- purge.BindInt(1, graph_id_);
- TF_RETURN_IF_ERROR(purge.StepAndReset());
auto insert = db_->Prepare(R"sql(
INSERT INTO Nodes (graph_id, node_id, node_name, op, device, node_def)
VALUES (?, ?, ?, ?, ?, ?)
@@ -263,7 +275,7 @@ class GraphSaver {
Status SaveGraph() {
auto insert = db_->Prepare(R"sql(
- INSERT OR REPLACE INTO Graphs (graph_id, inserted_time, graph_def)
+ INSERT INTO Graphs (graph_id, inserted_time, graph_def)
VALUES (?, ?, ?)
)sql");
insert.BindInt(1, graph_id_);
@@ -281,62 +293,258 @@ class GraphSaver {
std::unordered_map<StringPiece, int64, StringPieceHasher> name_to_node_id_;
};
-class SummaryDbWriter : public SummaryWriterInterface {
+class RunWriter {
public:
- SummaryDbWriter(Env* env, std::shared_ptr<Sqlite> db)
- : SummaryWriterInterface(),
- env_(env),
- db_(std::move(db)),
- txn_(db_),
- run_id_{0LL} {}
- ~SummaryDbWriter() override {}
+ RunWriter(Env* env, std::shared_ptr<Sqlite> db, const string& experiment_name,
+ const string& run_name, const string& user_name)
+ : env_{env},
+ db_{std::move(db)},
+ id_allocator_{env_, db_.get()},
+ experiment_name_{experiment_name},
+ run_name_{run_name},
+ user_name_{user_name},
+ insert_tensor_{db_->Prepare(R"sql(
+ INSERT OR REPLACE INTO Tensors (tag_id, step, computed_time, tensor)
+ VALUES (?, ?, ?, ?)
+ )sql")} {}
+
+ ~RunWriter() {
+ if (run_id_ == kAbsent) return;
+ auto update = db_->Prepare(R"sql(
+ UPDATE Runs SET finished_time = ? WHERE run_id = ?
+ )sql");
+ update.BindDouble(1, GetWallTime(env_));
+ update.BindInt(2, run_id_);
+ Status s = update.StepAndReset();
+ if (!s.ok()) {
+ LOG(ERROR) << "Failed to set Runs[" << run_id_
+ << "].finish_time: " << s.ToString();
+ }
+ }
- Status Initialize(const string& experiment_name, const string& run_name,
- const string& user_name) {
- mutex_lock ml(mu_);
- insert_tensor_ = db_->Prepare(R"sql(
- INSERT OR REPLACE INTO Tensors (tag_id, step, computed_time, tensor)
- VALUES (?, ?, ?, ?)
+ Status InsertTensor(int64 tag_id, int64 step, double computed_time,
+ Tensor t) {
+ insert_tensor_.BindInt(1, tag_id);
+ insert_tensor_.BindInt(2, step);
+ insert_tensor_.BindDouble(3, computed_time);
+ if (t.shape().dims() == 0 && t.dtype() == DT_INT64) {
+ insert_tensor_.BindInt(4, t.scalar<int64>()());
+ } else if (t.shape().dims() == 0 && t.dtype() == DT_DOUBLE) {
+ insert_tensor_.BindDouble(4, t.scalar<double>()());
+ } else {
+ TF_RETURN_IF_ERROR(BindTensor(&insert_tensor_, 4, t));
+ }
+ return insert_tensor_.StepAndReset();
+ }
+
+ Status InsertGraph(std::unique_ptr<GraphDef> g, double computed_time) {
+ TF_RETURN_IF_ERROR(InitializeRun(computed_time));
+ int64 graph_id;
+ TF_RETURN_IF_ERROR(
+ GraphSaver::Save(env_, db_.get(), &id_allocator_, g.get(), &graph_id));
+ if (run_id_ != kAbsent) {
+ auto set = db_->Prepare("UPDATE Runs SET graph_id = ? WHERE run_id = ?");
+ set.BindInt(1, graph_id);
+ set.BindInt(2, run_id_);
+ TF_RETURN_IF_ERROR(set.StepAndReset());
+ }
+ return Status::OK();
+ }
+
+ Status GetTagId(double computed_time, const string& tag_name,
+ const SummaryMetadata& metadata, int64* tag_id) {
+ TF_RETURN_IF_ERROR(InitializeRun(computed_time));
+ auto e = tag_ids_.find(tag_name);
+ if (e != tag_ids_.end()) {
+ *tag_id = e->second;
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(tag_id));
+ tag_ids_[tag_name] = *tag_id;
+ if (!metadata.summary_description().empty()) {
+ SqliteStatement insert_description = db_->Prepare(R"sql(
+ INSERT INTO Descriptions (id, description) VALUES (?, ?)
+ )sql");
+ insert_description.BindInt(1, *tag_id);
+ insert_description.BindText(2, metadata.summary_description());
+ TF_RETURN_IF_ERROR(insert_description.StepAndReset());
+ }
+ SqliteStatement insert = db_->Prepare(R"sql(
+ INSERT INTO Tags (
+ run_id,
+ tag_id,
+ tag_name,
+ inserted_time,
+ display_name,
+ plugin_name,
+ plugin_data
+ ) VALUES (?, ?, ?, ?, ?, ?, ?)
)sql");
- update_metadata_ = db_->Prepare(R"sql(
- UPDATE Tags SET metadata = ? WHERE tag_id = ?
+ if (run_id_ != kAbsent) insert.BindInt(1, run_id_);
+ insert.BindInt(2, *tag_id);
+ insert.BindText(3, tag_name);
+ insert.BindDouble(4, GetWallTime(env_));
+ if (!metadata.display_name().empty()) {
+ insert.BindText(5, metadata.display_name());
+ }
+ if (!metadata.plugin_data().plugin_name().empty()) {
+ insert.BindText(6, metadata.plugin_data().plugin_name());
+ }
+ if (!metadata.plugin_data().content().empty()) {
+ insert.BindBlob(7, metadata.plugin_data().content());
+ }
+ return insert.StepAndReset();
+ }
+
+ private:
+ Status InitializeUser() {
+ if (user_id_ != kAbsent || user_name_.empty()) return Status::OK();
+ SqliteStatement get = db_->Prepare(R"sql(
+ SELECT user_id FROM Users WHERE user_name = ?
)sql");
- experiment_name_ = experiment_name;
- run_name_ = run_name;
- user_name_ = user_name;
+ get.BindText(1, user_name_);
+ bool is_done;
+ TF_RETURN_IF_ERROR(get.Step(&is_done));
+ if (!is_done) {
+ user_id_ = get.ColumnInt(0);
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&user_id_));
+ SqliteStatement insert = db_->Prepare(R"sql(
+ INSERT INTO Users (user_id, user_name, inserted_time) VALUES (?, ?, ?)
+ )sql");
+ insert.BindInt(1, user_id_);
+ insert.BindText(2, user_name_);
+ insert.BindDouble(3, GetWallTime(env_));
+ TF_RETURN_IF_ERROR(insert.StepAndReset());
+ return Status::OK();
+ }
+
+ Status InitializeExperiment(double computed_time) {
+ if (experiment_name_.empty()) return Status::OK();
+ if (experiment_id_ == kAbsent) {
+ TF_RETURN_IF_ERROR(InitializeUser());
+ SqliteStatement get = db_->Prepare(R"sql(
+ SELECT
+ experiment_id,
+ started_time
+ FROM
+ Experiments
+ WHERE
+ user_id IS ?
+ AND experiment_name = ?
+ )sql");
+ if (user_id_ != kAbsent) get.BindInt(1, user_id_);
+ get.BindText(2, experiment_name_);
+ bool is_done;
+ TF_RETURN_IF_ERROR(get.Step(&is_done));
+ if (!is_done) {
+ experiment_id_ = get.ColumnInt(0);
+ experiment_started_time_ = get.ColumnInt(1);
+ } else {
+ TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&experiment_id_));
+ experiment_started_time_ = computed_time;
+ SqliteStatement insert = db_->Prepare(R"sql(
+ INSERT INTO Experiments (
+ user_id,
+ experiment_id,
+ experiment_name,
+ inserted_time,
+ started_time
+ ) VALUES (?, ?, ?, ?, ?)
+ )sql");
+ if (user_id_ != kAbsent) insert.BindInt(1, user_id_);
+ insert.BindInt(2, experiment_id_);
+ insert.BindText(3, experiment_name_);
+ insert.BindDouble(4, GetWallTime(env_));
+ insert.BindDouble(5, computed_time);
+ TF_RETURN_IF_ERROR(insert.StepAndReset());
+ }
+ }
+ if (computed_time < experiment_started_time_) {
+ experiment_started_time_ = computed_time;
+ SqliteStatement update = db_->Prepare(R"sql(
+ UPDATE Experiments SET started_time = ? WHERE experiment_id = ?
+ )sql");
+ update.BindDouble(1, computed_time);
+ update.BindInt(2, experiment_id_);
+ TF_RETURN_IF_ERROR(update.StepAndReset());
+ }
return Status::OK();
}
- // TODO(@jart): Use transactions that COMMIT on Flush()
- // TODO(@jart): Retry Commit() on SQLITE_BUSY with exponential back-off.
+ Status InitializeRun(double computed_time) {
+ if (run_name_.empty()) return Status::OK();
+ TF_RETURN_IF_ERROR(InitializeExperiment(computed_time));
+ if (run_id_ == kAbsent) {
+ TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&run_id_));
+ run_started_time_ = computed_time;
+ SqliteStatement insert = db_->Prepare(R"sql(
+ INSERT OR REPLACE INTO Runs (
+ experiment_id,
+ run_id,
+ run_name,
+ inserted_time,
+ started_time
+ ) VALUES (?, ?, ?, ?, ?)
+ )sql");
+ if (experiment_id_ != kAbsent) insert.BindInt(1, experiment_id_);
+ insert.BindInt(2, run_id_);
+ insert.BindText(3, run_name_);
+ insert.BindDouble(4, GetWallTime(env_));
+ insert.BindDouble(5, computed_time);
+ TF_RETURN_IF_ERROR(insert.StepAndReset());
+ }
+ if (computed_time < run_started_time_) {
+ run_started_time_ = computed_time;
+ SqliteStatement update = db_->Prepare(R"sql(
+ UPDATE Runs SET started_time = ? WHERE run_id = ?
+ )sql");
+ update.BindDouble(1, computed_time);
+ update.BindInt(2, run_id_);
+ TF_RETURN_IF_ERROR(update.StepAndReset());
+ }
+ return Status::OK();
+ }
+
+ Env* env_;
+ std::shared_ptr<Sqlite> db_;
+ IdAllocator id_allocator_;
+ const string experiment_name_;
+ const string run_name_;
+ const string user_name_;
+ int64 experiment_id_ = kAbsent;
+ int64 run_id_ = kAbsent;
+ int64 user_id_ = kAbsent;
+ std::unordered_map<string, int64> tag_ids_;
+ double experiment_started_time_ = 0.0;
+ double run_started_time_ = 0.0;
+ SqliteStatement insert_tensor_;
+};
+
+class SummaryDbWriter : public SummaryWriterInterface {
+ public:
+ SummaryDbWriter(Env* env, std::shared_ptr<Sqlite> db,
+ const string& experiment_name, const string& run_name,
+ const string& user_name)
+ : SummaryWriterInterface(),
+ env_{env},
+ run_writer_{env, std::move(db), experiment_name, run_name, user_name} {}
+ ~SummaryDbWriter() override {}
+
Status Flush() override { return Status::OK(); }
Status WriteTensor(int64 global_step, Tensor t, const string& tag,
const string& serialized_metadata) override {
mutex_lock ml(mu_);
- TF_RETURN_IF_ERROR(InitializeParents());
- // TODO(@jart): Memoize tag_id.
- int64 tag_id;
- TF_RETURN_IF_ERROR(GetTagId(run_id_, tag, &tag_id));
+ SummaryMetadata metadata;
if (!serialized_metadata.empty()) {
- // TODO(@jart): Only update metadata for first tensor.
- update_metadata_.BindBlobUnsafe(1, serialized_metadata);
- update_metadata_.BindInt(2, tag_id);
- TF_RETURN_IF_ERROR(update_metadata_.StepAndReset());
+ metadata.ParseFromString(serialized_metadata);
}
- // TODO(@jart): Lease blocks of rowids and *_ids to minimize fragmentation.
- // TODO(@jart): Check for random ID collisions without needing txn retry.
- insert_tensor_.BindInt(1, tag_id);
- insert_tensor_.BindInt(2, global_step);
- insert_tensor_.BindDouble(3, GetWallTime(env_));
- if (t.shape().dims() == 0 && t.dtype() == DT_INT64) {
- insert_tensor_.BindInt(4, t.scalar<int64>()());
- } else if (t.shape().dims() == 0 && t.dtype() == DT_DOUBLE) {
- insert_tensor_.BindDouble(4, t.scalar<double>()());
- } else {
- TF_RETURN_IF_ERROR(BindTensor(&insert_tensor_, 4, t));
- }
- return insert_tensor_.StepAndReset();
+ double now = GetWallTime(env_);
+ int64 tag_id;
+ TF_RETURN_IF_ERROR(run_writer_.GetTagId(now, tag, metadata, &tag_id));
+ return run_writer_.InsertTensor(tag_id, global_step, now, t);
}
Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
@@ -348,28 +556,26 @@ class SummaryDbWriter : public SummaryWriterInterface {
Status WriteGraph(int64 global_step, std::unique_ptr<GraphDef> g) override {
mutex_lock ml(mu_);
- TF_RETURN_IF_ERROR(InitializeParents());
- return txn_.Transact(GraphSaver::SaveToRun, env_, db_.get(), g.get(),
- run_id_);
+ return run_writer_.InsertGraph(std::move(g), GetWallTime(env_));
}
Status WriteEvent(std::unique_ptr<Event> e) override {
switch (e->what_case()) {
case Event::WhatCase::kSummary: {
mutex_lock ml(mu_);
- TF_RETURN_IF_ERROR(InitializeParents());
- const Summary& summary = e->summary();
- for (int i = 0; i < summary.value_size(); ++i) {
- TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i)));
+ Status s;
+ for (const auto& value : e->summary().value()) {
+ s.Update(WriteSummary(e.get(), value));
}
- return Status::OK();
+ return s;
}
case Event::WhatCase::kGraphDef: {
+ mutex_lock ml(mu_);
std::unique_ptr<GraphDef> graph{new GraphDef};
if (!ParseProtoUnlimited(graph.get(), e->graph_def())) {
return errors::DataLoss("parse event.graph_def failed");
}
- return WriteGraph(e->step(), std::move(graph));
+ return run_writer_.InsertGraph(std::move(graph), e->wall_time());
}
default:
// TODO(@jart): Handle other stuff.
@@ -401,128 +607,26 @@ class SummaryDbWriter : public SummaryWriterInterface {
string DebugString() override { return "SummaryDbWriter"; }
private:
- Status InitializeParents() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (run_id_ > 0) {
- return Status::OK();
- }
- int64 user_id;
- TF_RETURN_IF_ERROR(GetUserId(user_name_, &user_id));
- int64 experiment_id;
- TF_RETURN_IF_ERROR(
- GetExperimentId(user_id, experiment_name_, &experiment_id));
- TF_RETURN_IF_ERROR(GetRunId(experiment_id, run_name_, &run_id_));
- return Status::OK();
- }
-
- Status GetUserId(const string& user_name, int64* user_id)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (user_name.empty()) {
- *user_id = 0LL;
- return Status::OK();
- }
- SqliteStatement get_user_id = db_->Prepare(R"sql(
- SELECT user_id FROM Users WHERE user_name = ?
- )sql");
- get_user_id.BindText(1, user_name);
- bool is_done;
- TF_RETURN_IF_ERROR(get_user_id.Step(&is_done));
- if (!is_done) {
- *user_id = get_user_id.ColumnInt(0);
- } else {
- *user_id = MakeRandomId();
- SqliteStatement insert_user = db_->Prepare(R"sql(
- INSERT INTO Users (user_id, user_name, inserted_time) VALUES (?, ?, ?)
- )sql");
- insert_user.BindInt(1, *user_id);
- insert_user.BindText(2, user_name);
- insert_user.BindDouble(3, GetWallTime(env_));
- TF_RETURN_IF_ERROR(insert_user.StepAndReset());
- }
- return Status::OK();
- }
-
- Status GetExperimentId(int64 user_id, const string& experiment_name,
- int64* experiment_id) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- // TODO(@jart): Compute started_time.
- return GetId("Experiments", "user_id", user_id, "experiment_name",
- experiment_name, "experiment_id", experiment_id);
- }
-
- Status GetRunId(int64 experiment_id, const string& run_name, int64* run_id)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- // TODO(@jart): Compute started_time.
- return GetId("Runs", "experiment_id", experiment_id, "run_name", run_name,
- "run_id", run_id);
- }
-
- Status GetTagId(int64 run_id, const string& tag_name, int64* tag_id)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- return GetId("Tags", "run_id", run_id, "tag_name", tag_name, "tag_id",
- tag_id);
- }
-
- Status GetId(const char* table, const char* parent_id_field, int64 parent_id,
- const char* name_field, const string& name, const char* id_field,
- int64* id) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (name.empty()) {
- *id = 0LL;
- return Status::OK();
- }
- SqliteStatement select = db_->Prepare(
- strings::Printf("SELECT %s FROM %s WHERE %s = ? AND %s = ?", id_field,
- table, parent_id_field, name_field));
- if (parent_id > 0) {
- select.BindInt(1, parent_id);
- }
- select.BindText(2, name);
- bool is_done;
- TF_RETURN_IF_ERROR(select.Step(&is_done));
- if (!is_done) {
- *id = select.ColumnInt(0);
- } else {
- *id = MakeRandomId();
- SqliteStatement insert = db_->Prepare(strings::Printf(
- "INSERT INTO %s (%s, %s, %s, inserted_time) VALUES (?, ?, ?, ?)",
- table, parent_id_field, id_field, name_field));
- if (parent_id > 0) {
- insert.BindInt(1, parent_id);
- }
- insert.BindInt(2, *id);
- insert.BindText(3, name);
- insert.BindDouble(4, GetWallTime(env_));
- TF_RETURN_IF_ERROR(insert.StepAndReset());
- }
- return Status::OK();
- }
-
Status WriteSummary(const Event* e, const Summary::Value& summary)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- int64 tag_id;
- TF_RETURN_IF_ERROR(GetTagId(run_id_, summary.tag(), &tag_id));
- insert_tensor_.BindInt(1, tag_id);
- insert_tensor_.BindInt(2, e->step());
- insert_tensor_.BindDouble(3, e->wall_time());
switch (summary.value_case()) {
- case Summary::Value::ValueCase::kSimpleValue:
- insert_tensor_.BindDouble(4, summary.simple_value());
- break;
+ case Summary::Value::ValueCase::kSimpleValue: {
+ int64 tag_id;
+ TF_RETURN_IF_ERROR(run_writer_.GetTagId(e->wall_time(), summary.tag(),
+ summary.metadata(), &tag_id));
+ Tensor t{DT_DOUBLE, {}};
+ t.scalar<double>()() = summary.simple_value();
+ return run_writer_.InsertTensor(tag_id, e->step(), e->wall_time(), t);
+ }
default:
// TODO(@jart): Handle the rest.
return Status::OK();
}
- return insert_tensor_.StepAndReset();
}
mutex mu_;
Env* env_;
- std::shared_ptr<Sqlite> db_ GUARDED_BY(mu_);
- Transactor txn_ GUARDED_BY(mu_);
- SqliteStatement insert_tensor_ GUARDED_BY(mu_);
- SqliteStatement update_metadata_ GUARDED_BY(mu_);
- string user_name_ GUARDED_BY(mu_);
- string experiment_name_ GUARDED_BY(mu_);
- string run_name_ GUARDED_BY(mu_);
- int64 run_id_ GUARDED_BY(mu_);
+ RunWriter run_writer_ GUARDED_BY(mu_);
};
} // namespace
@@ -532,14 +636,8 @@ Status CreateSummaryDbWriter(std::shared_ptr<Sqlite> db,
const string& run_name, const string& user_name,
Env* env, SummaryWriterInterface** result) {
TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db));
- SummaryDbWriter* w = new SummaryDbWriter(env, std::move(db));
- const Status s = w->Initialize(experiment_name, run_name, user_name);
- if (!s.ok()) {
- w->Unref();
- *result = nullptr;
- return s;
- }
- *result = w;
+ *result = new SummaryDbWriter(env, std::move(db), experiment_name, run_name,
+ user_name);
return Status::OK();
}
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
index 625861fa6b..5ea844b668 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
@@ -101,6 +101,7 @@ TEST_F(SummaryDbWriterTest, NothingWritten_NoRowsCreated) {
TF_ASSERT_OK(writer_->Flush());
writer_->Unref();
writer_ = nullptr;
+ EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Ids"));
EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users"));
EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs"));
@@ -109,13 +110,24 @@ TEST_F(SummaryDbWriterTest, NothingWritten_NoRowsCreated) {
}
TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) {
+ SummaryMetadata metadata;
+ metadata.set_display_name("display_name");
+ metadata.set_summary_description("description");
+ metadata.mutable_plugin_data()->set_plugin_name("plugin_name");
+ metadata.mutable_plugin_data()->set_content("plugin_data");
+ SummaryMetadata metadata_nope;
+ metadata_nope.set_display_name("nope");
+ metadata_nope.set_summary_description("nope");
+ metadata_nope.mutable_plugin_data()->set_plugin_name("nope");
+ metadata_nope.mutable_plugin_data()->set_content("nope");
TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_,
&writer_));
env_.AdvanceByMillis(23);
TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy",
- "this-is-metaaa"));
+ metadata.SerializeAsString()));
env_.AdvanceByMillis(23);
- TF_ASSERT_OK(writer_->WriteTensor(2, MakeScalarInt64(314LL), "taggy", ""));
+ TF_ASSERT_OK(writer_->WriteTensor(2, MakeScalarInt64(314LL), "taggy",
+ metadata_nope.SerializeAsString()));
TF_ASSERT_OK(writer_->Flush());
ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Users"));
@@ -148,27 +160,28 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) {
EXPECT_EQ(run_id, QueryInt("SELECT run_id FROM Tags"));
EXPECT_EQ("taggy", QueryString("SELECT tag_name FROM Tags"));
EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Tags"));
- EXPECT_EQ("this-is-metaaa", QueryString("SELECT metadata FROM Tags"));
+
+ EXPECT_EQ("display_name", QueryString("SELECT display_name FROM Tags"));
+ EXPECT_EQ("plugin_name", QueryString("SELECT plugin_name FROM Tags"));
+ EXPECT_EQ("plugin_data", QueryString("SELECT plugin_data FROM Tags"));
+ EXPECT_EQ("description", QueryString("SELECT description FROM Descriptions"));
EXPECT_EQ(tag_id, QueryInt("SELECT tag_id FROM Tensors WHERE step = 1"));
EXPECT_EQ(0.023,
QueryDouble("SELECT computed_time FROM Tensors WHERE step = 1"));
- EXPECT_EQ("this-is-metaaa", QueryString("SELECT metadata FROM Tags"));
EXPECT_FALSE(
QueryString("SELECT tensor FROM Tensors WHERE step = 1").empty());
EXPECT_EQ(tag_id, QueryInt("SELECT tag_id FROM Tensors WHERE step = 2"));
EXPECT_EQ(0.046,
QueryDouble("SELECT computed_time FROM Tensors WHERE step = 2"));
- EXPECT_EQ("this-is-metaaa", QueryString("SELECT metadata FROM Tags"));
EXPECT_FALSE(
QueryString("SELECT tensor FROM Tensors WHERE step = 2").empty());
}
TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) {
TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
- TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy",
- "this-is-metaaa"));
+ TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy", ""));
TF_ASSERT_OK(writer_->Flush());
ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users"));
ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
@@ -317,5 +330,39 @@ TEST_F(SummaryDbWriterTest, WriteScalarUint8_CoercesToInt64) {
ASSERT_EQ(254LL, QueryInt("SELECT tensor FROM Tensors"));
}
+TEST_F(SummaryDbWriterTest, UsesIdsTable) {
+ SummaryMetadata metadata;
+ TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_,
+ &writer_));
+ env_.AdvanceByMillis(23);
+ TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy",
+ metadata.SerializeAsString()));
+ TF_ASSERT_OK(writer_->Flush());
+ ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Ids"));
+ EXPECT_EQ(4LL, QueryInt(strings::StrCat(
+ "SELECT COUNT(*) FROM Ids WHERE id IN (",
+ QueryInt("SELECT user_id FROM Users"), ", ",
+ QueryInt("SELECT experiment_id FROM Experiments"), ", ",
+ QueryInt("SELECT run_id FROM Runs"), ", ",
+ QueryInt("SELECT tag_id FROM Tags"), ")")));
+}
+
+TEST_F(SummaryDbWriterTest, SetsRunFinishedTime) {
+ SummaryMetadata metadata;
+ TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_,
+ &writer_));
+ env_.AdvanceByMillis(23);
+ TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy",
+ metadata.SerializeAsString()));
+ TF_ASSERT_OK(writer_->Flush());
+ ASSERT_EQ(0.023, QueryDouble("SELECT started_time FROM Runs"));
+ ASSERT_EQ(0.0, QueryDouble("SELECT finished_time FROM Runs"));
+ env_.AdvanceByMillis(23);
+ writer_->Unref();
+ writer_ = nullptr;
+ ASSERT_EQ(0.023, QueryDouble("SELECT started_time FROM Runs"));
+ ASSERT_EQ(0.046, QueryDouble("SELECT finished_time FROM Runs"));
+}
+
} // namespace
} // namespace tensorflow