author    Justine Tunney <jart@google.com>    2018-01-11 16:08:50 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-01-11 16:12:46 -0800
commit  febdd26ae594133d24f82544706b1e012a5cf1ea (patch)
tree    dd325008019ab10ce35f98368bf392ce4a118ec9 /tensorflow/contrib/tensorboard
parent  fc252eb976c98c95a625ea6e6a0486334d3c5b6e (diff)
Add reservoir sampling to DB summary writer
This thing is kind of cool. It's able to turn a 350MB event log into a 35MB SQLite file at 80MB/s with one MacBook core. Best of all, this was accomplished using a normalized schema without the embedded protos.
PiperOrigin-RevId: 181676380
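For readers unfamiliar with the technique, the sampling scheme the new SeriesWriter below relies on is classic Algorithm R: keep the first k samples, then replace a random reservoir slot with probability k/n. A minimal standalone sketch mirroring the logic in SeriesWriter::Append (the Reservoir class is illustrative only, not part of this change):

#include <cstdint>
#include <random>
#include <utility>
#include <vector>

// Algorithm R: after n observations, each one is retained with
// probability slots/n, using O(slots) memory.
template <typename T>
class Reservoir {
 public:
  explicit Reservoir(size_t slots)
      : slots_(slots), rng_(std::mt19937_64::default_seed) {}

  void Add(T value) {
    if (count_ < slots_) {
      samples_.push_back(std::move(value));  // fill phase
    } else {
      size_t i = rng_() % (count_ + 1);  // uniform in [0, count_]
      if (i < slots_) samples_[i] = std::move(value);  // replace phase
    }
    ++count_;
  }

  const std::vector<T>& samples() const { return samples_; }

 private:
  const size_t slots_;
  std::mt19937_64 rng_;
  size_t count_ = 0;
  std::vector<T> samples_;
};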
Diffstat (limited to 'tensorflow/contrib/tensorboard')
-rw-r--r--  tensorflow/contrib/tensorboard/db/BUILD                          6
-rw-r--r--  tensorflow/contrib/tensorboard/db/schema.cc                    239
-rw-r--r--  tensorflow/contrib/tensorboard/db/summary_db_writer.cc        1206
-rw-r--r--  tensorflow/contrib/tensorboard/db/summary_db_writer.h            7
-rw-r--r--  tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc     71
5 files changed, 1058 insertions, 471 deletions
diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD
index 3a3402c59b..4c9cc4ccd6 100644
--- a/tensorflow/contrib/tensorboard/db/BUILD
+++ b/tensorflow/contrib/tensorboard/db/BUILD
@@ -5,12 +5,13 @@ package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_copts")
cc_library(
name = "schema",
srcs = ["schema.cc"],
hdrs = ["schema.h"],
+ copts = tf_copts(),
deps = [
"//tensorflow/core:lib",
"//tensorflow/core/lib/db:sqlite",
@@ -32,8 +33,10 @@ cc_library(
name = "summary_db_writer",
srcs = ["summary_db_writer.cc"],
hdrs = ["summary_db_writer.h"],
+ copts = tf_copts(),
deps = [
":schema",
+ "//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
@@ -47,6 +50,7 @@ tf_cc_test(
size = "small",
srcs = ["summary_db_writer_test.cc"],
deps = [
+ ":schema",
":summary_db_writer",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/contrib/tensorboard/db/schema.cc b/tensorflow/contrib/tensorboard/db/schema.cc
index 2cd00876f8..6ccd386dc0 100644
--- a/tensorflow/contrib/tensorboard/db/schema.cc
+++ b/tensorflow/contrib/tensorboard/db/schema.cc
@@ -22,8 +22,7 @@ namespace {
Status Run(Sqlite* db, const char* sql) {
SqliteStatement stmt;
TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt));
- TF_RETURN_IF_ERROR(stmt.StepAndReset());
- return Status::OK();
+ return stmt.StepAndReset();
}
} // namespace
@@ -38,37 +37,34 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
db->PrepareOrDie("PRAGMA user_version=0").StepAndResetOrDie();
Status s;
- // Creates Ids table.
+ // Ids identify resources.
//
- // This table must be used to randomly allocate Permanent IDs for
- // all top-level tables, in order to maintain an invariant where
- // foo_id != bar_id for all IDs of any two tables.
+ // This table can be used to efficiently generate Permanent IDs in
+ // conjunction with a random number generator. Unlike rowids, these
+ // IDs are safe to use in URLs and unique across tables.
//
- // A row should only be deleted from this table if it can be
- // guaranteed that it exists absolutely nowhere else in the entire
- // system.
+ // Within any given system, foo_id == bar_id must never hold for
+ // any rows of any two (Foos, Bars) tables. A row should only be
+ // deleted from this table if there's a very high level of confidence
+ // it exists nowhere else in the system.
//
// Fields:
- // id: An ID that was allocated globally. This must be in the
- // range [1,2**47). 0 is assigned the same meaning as NULL and
- // shouldn't be stored; 2**63-1 is reserved for statically
- // allocating space in a page to UPDATE later; and all other
- // int64 values are reserved for future use.
+ // id: The system-wide ID. This must be in the range [1,2**47). 0
+ // is assigned the same meaning as NULL and shouldn't be stored;
+ // all other int64 values are reserved for future use. Please
+ // note that id is also the rowid.
s.Update(Run(db, R"sql(
CREATE TABLE IF NOT EXISTS Ids (
id INTEGER PRIMARY KEY
)
)sql"));
- // Creates Descriptions table.
- //
- // This table allows TensorBoard to associate Markdown text with any
- // object in the database that has a Permanent ID.
+ // Descriptions are Markdown text that can be associated with any
+ // resource that has a Permanent ID.
//
// Fields:
- // id: The Permanent ID of the associated object. This is also the
- // SQLite rowid.
- // description: Arbitrary Markdown text.
+ // id: The foo_id of the associated row in Foos.
+ // description: Arbitrary NUL-terminated Markdown text.
s.Update(Run(db, R"sql(
CREATE TABLE IF NOT EXISTS Descriptions (
id INTEGER PRIMARY KEY,
@@ -76,121 +72,136 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
)
)sql"));
- // Creates Tensors table.
+ // Tensors are 0..n-dimensional numbers or strings.
//
// Fields:
- // rowid: Ephemeral b-tree ID dictating locality.
- // tag_id: ID of associated Tag.
+ // rowid: Ephemeral b-tree ID.
+ // series: The Permanent ID of a different resource, e.g. tag_id. A
+ // tensor will be vacuumed if no series == foo_id exists for any
+ // row of any Foos table. When series is NULL this tensor may serve
+ // undefined purposes. This field should be set on placeholders.
+ // step: Arbitrary number to uniquely order tensors within series.
+ // The meaning of step is undefined when series is NULL. This may
+ // be set on placeholders to prepopulate index pages.
// computed_time: Float UNIX timestamp with microsecond precision.
// In the old summaries system that uses FileWriter, this is the
// wall time around when tf.Session.run finished. In the new
// summaries system, it is the wall time of when the tensor was
// computed. On systems with monotonic clocks, it is calculated
// by adding the monotonic run duration to Run.started_time.
- // This field is not indexed because, in practice, it should be
- // ordered the same or nearly the same as TensorIndex, so local
- // insertion sort might be more suitable.
- // step: User-supplied number, ordering this tensor in Tag.
- // If NULL then the Tag must have only one Tensor.
- // tensor: Can be an INTEGER (DT_INT64), FLOAT (DT_DOUBLE), or
- // BLOB. The structure of a BLOB is currently undefined, but in
- // essence it is a Snappy tf.TensorProto that spills over into
- // TensorChunks.
+ // dtype: The tensorflow::DataType ID. For example, DT_INT64 is 9.
+ // When NULL or 0 this must be treated as a placeholder row that
+ // does not officially exist.
+ // shape: A comma-delimited list of int64 >=0 values representing
+ // the length of each dimension in the tensor. This must be a valid
+ // shape. That means no -1 values and, in the case of numeric
+ // tensors, length(data) == product(shape) * sizeof(dtype). Empty
+ // means this is a scalar, a.k.a. a 0-dimensional tensor.
+ // data: Little-endian raw tensor memory. If dtype is DT_STRING and
+ // shape is empty, the nullness of this field indicates whether or
+ // not it contains the tensor contents; otherwise TensorStrings
+ // must be queried. If dtype is NULL then ZEROBLOB can be used on
+ // this field to reserve row space to be updated later.
s.Update(Run(db, R"sql(
CREATE TABLE IF NOT EXISTS Tensors (
rowid INTEGER PRIMARY KEY,
- tag_id INTEGER NOT NULL,
- computed_time REAL,
+ series INTEGER,
step INTEGER,
- tensor BLOB
+ dtype INTEGER,
+ computed_time REAL,
+ shape TEXT,
+ data BLOB
)
)sql"));
- // Uniquely indexes (tag_id, step) on Tensors table.
s.Update(Run(db, R"sql(
- CREATE UNIQUE INDEX IF NOT EXISTS TensorIndex
- ON Tensors (tag_id, step)
+ CREATE UNIQUE INDEX IF NOT EXISTS
+ TensorSeriesStepIndex
+ ON
+ Tensors (series, step)
+ WHERE
+ series IS NOT NULL
+ AND step IS NOT NULL
)sql"));
- // Creates TensorChunks table.
+ // TensorStrings are the flat contents of 1..n dimensional DT_STRING
+ // Tensors.
//
- // This table can be used to split up a tensor across many rows,
- // which has the advantage of not slowing down table scans on the
- // main table, allowing asynchronous fetching, minimizing copying,
- // and preventing large buffers from being allocated.
+ // The number of rows associated with a Tensor must be equal to the
+ // product of its Tensors.shape.
//
// Fields:
- // rowid: Ephemeral b-tree ID dictating locality.
- // tag_id: ID of associated Tag.
- // step: Same as corresponding Tensors.step.
- // sequence: 1-indexed sequence number for ordering chunks. Please
- // note that the 0th index is Tensors.tensor.
- // chunk: Bytes of next chunk in tensor.
+ // rowid: Ephemeral b-tree ID.
+ // tensor_rowid: References Tensors.rowid.
+ // idx: Index in flattened tensor, starting at 0.
+ // data: The string value at a particular index. NUL characters are
+ // permitted.
s.Update(Run(db, R"sql(
- CREATE TABLE IF NOT EXISTS TensorChunks (
+ CREATE TABLE IF NOT EXISTS TensorStrings (
rowid INTEGER PRIMARY KEY,
- tag_id INTEGER NOT NULL,
- step INTEGER,
- sequence INTEGER,
- chunk BLOB
+ tensor_rowid INTEGER NOT NULL,
+ idx INTEGER NOT NULL,
+ data BLOB
)
)sql"));
- // Uniquely indexes (tag_id, step, sequence) on TensorChunks table.
s.Update(Run(db, R"sql(
- CREATE UNIQUE INDEX IF NOT EXISTS TensorChunkIndex
- ON TensorChunks (tag_id, step, sequence)
+ CREATE UNIQUE INDEX IF NOT EXISTS TensorStringIndex
+ ON TensorStrings (tensor_rowid, idx)
)sql"));
- // Creates Tags table.
+ // Tags are series of Tensors.
//
// Fields:
- // rowid: Ephemeral b-tree ID dictating locality.
+ // rowid: Ephemeral b-tree ID.
// tag_id: The Permanent ID of the Tag.
// run_id: Optional ID of associated Run.
- // tag_name: The tag field in summary.proto, unique across Run.
// inserted_time: Float UNIX timestamp with µs precision. This is
// always the wall time of when the row was inserted into the
// DB. It may be used as a hint for an archival job.
+ // tag_name: The tag field in summary.proto, unique across Run.
// display_name: Optional for GUI and defaults to tag_name.
// plugin_name: Arbitrary TensorBoard plugin name for dispatch.
// plugin_data: Arbitrary data that plugin wants.
+ //
+ // TODO(jart): Maybe there should be a Plugins table?
s.Update(Run(db, R"sql(
CREATE TABLE IF NOT EXISTS Tags (
rowid INTEGER PRIMARY KEY,
run_id INTEGER,
tag_id INTEGER NOT NULL,
- tag_name TEXT,
inserted_time DOUBLE,
+ tag_name TEXT,
display_name TEXT,
plugin_name TEXT,
plugin_data BLOB
)
)sql"));
- // Uniquely indexes tag_id on Tags table.
s.Update(Run(db, R"sql(
CREATE UNIQUE INDEX IF NOT EXISTS TagIdIndex
ON Tags (tag_id)
)sql"));
- // Uniquely indexes (run_id, tag_name) on Tags table.
s.Update(Run(db, R"sql(
- CREATE UNIQUE INDEX IF NOT EXISTS TagNameIndex
- ON Tags (run_id, tag_name)
- WHERE tag_name IS NOT NULL
+ CREATE UNIQUE INDEX IF NOT EXISTS
+ TagRunNameIndex
+ ON
+ Tags (run_id, tag_name)
+ WHERE
+ run_id IS NOT NULL
+ AND tag_name IS NOT NULL
)sql"));
- // Creates Runs table.
+ // Runs are groups of Tags.
//
- // This table stores information about Runs. Each row usually
- // represents a single attempt at training or testing a TensorFlow
- // model, with a given set of hyper-parameters, whose summaries are
- // written out to a single event logs directory with a monotonic step
- // counter.
+ // Each Run usually represents a single attempt at training or testing
+ // a TensorFlow model, with a given set of hyper-parameters, whose
+ // summaries are written out to a single event logs directory with a
+ // monotonic step counter.
//
// Fields:
- // rowid: Ephemeral b-tree ID dictating locality.
+ // rowid: Ephemeral b-tree ID.
// run_id: The Permanent ID of the Run. This has a 1:1 mapping
// with a SummaryWriter instance. If two writers spawn for a
// given (user_name, experiment_name, run_name) then each should
@@ -199,8 +210,8 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
// previous invocations will then enter limbo, where they may be
// accessible for certain operations, but should be garbage
// collected eventually.
- // experiment_id: Optional ID of associated Experiment.
// run_name: User-supplied string, unique across Experiment.
+ // experiment_id: Optional ID of associated Experiment.
// inserted_time: Float UNIX timestamp with µs precision. This is
// always the time the row was inserted into the database. It
// does not change.
@@ -215,40 +226,33 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
// SummaryWriter resource that created this run was destroyed.
// Once this value becomes non-NULL a Run and its Tags and
// Tensors should be regarded as immutable.
- // graph_id: ID of associated Graphs row.
s.Update(Run(db, R"sql(
CREATE TABLE IF NOT EXISTS Runs (
rowid INTEGER PRIMARY KEY,
experiment_id INTEGER,
run_id INTEGER NOT NULL,
- run_name TEXT,
inserted_time REAL,
started_time REAL,
finished_time REAL,
- graph_id INTEGER
+ run_name TEXT
)
)sql"));
- // Uniquely indexes run_id on Runs table.
s.Update(Run(db, R"sql(
CREATE UNIQUE INDEX IF NOT EXISTS RunIdIndex
ON Runs (run_id)
)sql"));
- // Uniquely indexes (experiment_id, run_name) on Runs table.
s.Update(Run(db, R"sql(
CREATE UNIQUE INDEX IF NOT EXISTS RunNameIndex
ON Runs (experiment_id, run_name)
WHERE run_name IS NOT NULL
)sql"));
- // Creates Experiments table.
- //
- // This table stores information about experiments, which are sets of
- // runs.
+ // Experiments are groups of Runs.
//
// Fields:
- // rowid: Ephemeral b-tree ID dictating locality.
+ // rowid: Ephemeral b-tree ID.
// user_id: Optional ID of associated User.
// experiment_id: The Permanent ID of the Experiment.
// experiment_name: User-supplied string, unique across User.
@@ -259,34 +263,39 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
// the MIN(experiment.started_time, run.started_time) of each
// Run added to the database, including Runs which have since
// been overwritten.
+ // is_watching: A boolean indicating if someone is actively
+ // looking at this Experiment in the TensorBoard GUI. Tensor
+ // writers that do reservoir sampling can query this value to
+ // decide if they want the "keep last" behavior. This improves
+ // the performance of long-running training while allowing
+ // low-latency feedback in TensorBoard.
s.Update(Run(db, R"sql(
CREATE TABLE IF NOT EXISTS Experiments (
rowid INTEGER PRIMARY KEY,
user_id INTEGER,
experiment_id INTEGER NOT NULL,
- experiment_name TEXT,
inserted_time REAL,
- started_time REAL
+ started_time REAL,
+ is_watching INTEGER,
+ experiment_name TEXT
)
)sql"));
- // Uniquely indexes experiment_id on Experiments table.
s.Update(Run(db, R"sql(
CREATE UNIQUE INDEX IF NOT EXISTS ExperimentIdIndex
ON Experiments (experiment_id)
)sql"));
- // Uniquely indexes (user_id, experiment_name) on Experiments table.
s.Update(Run(db, R"sql(
CREATE UNIQUE INDEX IF NOT EXISTS ExperimentNameIndex
ON Experiments (user_id, experiment_name)
WHERE experiment_name IS NOT NULL
)sql"));
- // Creates Users table.
+ // Users are people who love TensorBoard.
//
// Fields:
- // rowid: Ephemeral b-tree ID dictating locality.
+ // rowid: Ephemeral b-tree ID.
// user_id: The Permanent ID of the User.
// user_name: Unique user name.
// email: Optional unique email address.
@@ -297,61 +306,66 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
CREATE TABLE IF NOT EXISTS Users (
rowid INTEGER PRIMARY KEY,
user_id INTEGER NOT NULL,
+ inserted_time REAL,
user_name TEXT,
- email TEXT,
- inserted_time REAL
+ email TEXT
)
)sql"));
- // Uniquely indexes user_id on Users table.
s.Update(Run(db, R"sql(
CREATE UNIQUE INDEX IF NOT EXISTS UserIdIndex
ON Users (user_id)
)sql"));
- // Uniquely indexes user_name on Users table.
s.Update(Run(db, R"sql(
CREATE UNIQUE INDEX IF NOT EXISTS UserNameIndex
ON Users (user_name)
WHERE user_name IS NOT NULL
)sql"));
- // Uniquely indexes email on Users table.
s.Update(Run(db, R"sql(
CREATE UNIQUE INDEX IF NOT EXISTS UserEmailIndex
ON Users (email)
WHERE email IS NOT NULL
)sql"));
- // Creates Graphs table.
+ // Graphs define how Tensors flowed in Runs.
//
// Fields:
- // rowid: Ephemeral b-tree ID dictating locality.
+ // rowid: Ephemeral b-tree ID.
+ // run_id: The Permanent ID of the associated Run. Only one Graph
+ // can be associated with a Run.
// graph_id: The Permanent ID of the Graph.
// inserted_time: Float UNIX timestamp with µs precision. This is
// always the wall time of when the row was inserted into the
// DB. It may be used as a hint for an archival job.
- // node_def: Contains Snappy tf.GraphDef proto. All fields will be
- // cleared except those not expressed in SQL.
+ // node_def: Contains tf.GraphDef proto. All fields will be cleared
+ // except those not expressed in SQL.
s.Update(Run(db, R"sql(
CREATE TABLE IF NOT EXISTS Graphs (
rowid INTEGER PRIMARY KEY,
+ run_id INTEGER,
graph_id INTEGER NOT NULL,
inserted_time REAL,
graph_def BLOB
)
)sql"));
- // Uniquely indexes graph_id on Graphs table.
s.Update(Run(db, R"sql(
CREATE UNIQUE INDEX IF NOT EXISTS GraphIdIndex
ON Graphs (graph_id)
)sql"));
- // Creates Nodes table.
+ s.Update(Run(db, R"sql(
+ CREATE UNIQUE INDEX IF NOT EXISTS GraphRunIndex
+ ON Graphs (run_id)
+ WHERE run_id IS NOT NULL
+ )sql"));
+
+ // Nodes are the vertices in Graphs.
//
// Fields:
- // rowid: Ephemeral b-tree ID dictating locality.
+ // rowid: Ephemeral b-tree ID.
// graph_id: The Permanent ID of the associated Graph.
// node_id: ID for this node. This is more like a 0-index within
// the Graph. Please note indexes are allowed to be removed.
@@ -361,8 +375,10 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
// node_def.name proto field must not be cleared.
// op: Copied from tf.NodeDef proto.
// device: Copied from tf.NodeDef proto.
- // node_def: Contains Snappy tf.NodeDef proto. All fields will be
- // cleared except those not expressed in SQL.
+ // node_def: Contains tf.NodeDef proto. All fields will be cleared
+ // except those not expressed in SQL.
+ //
+ // TODO(jart): Make separate tables for op and device strings.
s.Update(Run(db, R"sql(
CREATE TABLE IF NOT EXISTS Nodes (
rowid INTEGER PRIMARY KEY,
@@ -375,32 +391,35 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
)
)sql"));
- // Uniquely indexes (graph_id, node_id) on Nodes table.
s.Update(Run(db, R"sql(
CREATE UNIQUE INDEX IF NOT EXISTS NodeIdIndex
ON Nodes (graph_id, node_id)
)sql"));
- // Uniquely indexes (graph_id, node_name) on Nodes table.
s.Update(Run(db, R"sql(
CREATE UNIQUE INDEX IF NOT EXISTS NodeNameIndex
ON Nodes (graph_id, node_name)
WHERE node_name IS NOT NULL
)sql"));
- // Creates NodeInputs table.
+ // NodeInputs are directed edges between Nodes in Graphs.
//
// Fields:
- // rowid: Ephemeral b-tree ID dictating locality.
+ // rowid: Ephemeral b-tree ID.
// graph_id: The Permanent ID of the associated Graph.
// node_id: Index of Node in question. This can be considered the
// 'to' vertex.
// idx: Used for ordering inputs on a given Node.
// input_node_id: Nodes.node_id of the corresponding input node.
// This can be considered the 'from' vertex.
+ // input_node_idx: Since a Node can output multiple Tensors, this
+ // is the index of the output tensor that serves as this input.
+ // NULL is treated as 0.
// is_control: If non-zero, indicates this input is a control
// dependency, which means this isn't an edge through which
// tensors flow. NULL means 0.
+ //
+ // TODO(jart): Rename to NodeEdges.
s.Update(Run(db, R"sql(
CREATE TABLE IF NOT EXISTS NodeInputs (
rowid INTEGER PRIMARY KEY,
@@ -408,11 +427,11 @@ Status SetupTensorboardSqliteDb(Sqlite* db) {
node_id INTEGER NOT NULL,
idx INTEGER NOT NULL,
input_node_id INTEGER NOT NULL,
+ input_node_idx INTEGER,
is_control INTEGER
)
)sql"));
- // Uniquely indexes (graph_id, node_id, idx) on NodeInputs table.
s.Update(Run(db, R"sql(
CREATE UNIQUE INDEX IF NOT EXISTS NodeInputsIndex
ON NodeInputs (graph_id, node_id, idx)
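To illustrate how the reorganized schema is meant to be consumed, here is a hedged read-back sketch (a hypothetical helper, not part of this change): it fetches the sampled points of one series in step order, skipping placeholder rows, and uses only statement methods that appear elsewhere in this patch (Prepare, BindInt, Step, ColumnInt).

// Hypothetical: read the sampled points of one series in step order.
// Placeholder rows (dtype IS NULL) do not officially exist, so skip
// them. Decoding of the data column per dtype/shape is elided.
Status ReadSeries(Sqlite* db, int64 series_id) {
  const char* sql = R"sql(
    SELECT step, computed_time FROM Tensors
    WHERE series = ? AND dtype IS NOT NULL
    ORDER BY step
  )sql";
  SqliteStatement stmt;
  TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt));
  stmt.BindInt(1, series_id);
  bool is_done;
  TF_RETURN_IF_ERROR(stmt.Step(&is_done));
  while (!is_done) {
    LOG(INFO) << "step=" << stmt.ColumnInt(0);
    TF_RETURN_IF_ERROR(stmt.Step(&is_done));
  }
  return Status::OK();
}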
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
index 44887930c1..889ac43415 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
@@ -14,17 +14,37 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
-#include "tensorflow/contrib/tensorboard/db/schema.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/db/sqlite.h"
#include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/event.pb.h"
+// TODO(jart): Break this up into multiple files with excellent unit tests.
+// TODO(jart): Make decision to write in separate op.
+// TODO(jart): Add really good busy handling.
+
+// clang-format off
+#define CALL_SUPPORTED_TYPES(m) \
+ TF_CALL_string(m) \
+ TF_CALL_half(m) \
+ TF_CALL_float(m) \
+ TF_CALL_double(m) \
+ TF_CALL_complex64(m) \
+ TF_CALL_complex128(m) \
+ TF_CALL_int8(m) \
+ TF_CALL_int16(m) \
+ TF_CALL_int32(m) \
+ TF_CALL_int64(m) \
+ TF_CALL_uint8(m) \
+ TF_CALL_uint16(m) \
+ TF_CALL_uint32(m) \
+ TF_CALL_uint64(m)
+// clang-format on
+
namespace tensorflow {
namespace {
@@ -33,115 +53,145 @@ const uint64 kIdTiers[] = {
0x7fffffULL, // 23-bit (3 bytes on disk)
0x7fffffffULL, // 31-bit (4 bytes on disk)
0x7fffffffffffULL, // 47-bit (5 bytes on disk)
- // Remaining bits reserved for future use.
+ // remaining bits for future use
};
const int kMaxIdTier = sizeof(kIdTiers) / sizeof(uint64);
const int kIdCollisionDelayMicros = 10;
const int kMaxIdCollisions = 21; // sum(2**i*10µs for i in range(21))~=21s
const int64 kAbsent = 0LL;
-const int64 kReserved = 0x7fffffffffffffffLL;
-double GetWallTime(Env* env) {
+const char* kScalarPluginName = "scalars";
+const char* kImagePluginName = "images";
+const char* kAudioPluginName = "audio";
+const char* kHistogramPluginName = "histograms";
+
+const int kScalarSlots = 10000;
+const int kImageSlots = 10;
+const int kAudioSlots = 10;
+const int kHistogramSlots = 1;
+const int kTensorSlots = 10;
+
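+// Reservoir rows are pre-allocated with ZEROBLOB at ~1.5x the size of
+// the first sample (but at least 32 bytes), so later UPDATEs can
+// usually reuse the same space.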
+const int64 kReserveMinBytes = 32;
+const double kReserveMultiplier = 1.5;
+
+// "Flush" is a misnomer: what we're actually doing is issuing many
+// commits inside any SqliteTransaction that writes potentially
+// hundreds of megabytes but doesn't need the transaction to maintain
+// its invariants. This keeps the WAL read penalty small and may give
+// writers in other processes a chance to schedule.
+const uint64 kFlushBytes = 1024 * 1024;
+
+double DoubleTime(uint64 micros) {
// TODO(@jart): Follow precise definitions for time laid out in schema.
// TODO(@jart): Use monotonic clock from gRPC codebase.
- return static_cast<double>(env->NowMicros()) / 1.0e6;
+ return static_cast<double>(micros) / 1.0e6;
}
-Status Serialize(const protobuf::MessageLite& proto, string* output) {
- output->clear();
- if (!proto.SerializeToString(output)) {
- return errors::DataLoss("SerializeToString failed");
+string StringifyShape(const TensorShape& shape) {
+ string result;
+ bool first = true;
+ for (const auto& dim : shape) {
+ if (first) {
+ first = false;
+ } else {
+ strings::StrAppend(&result, ",");
+ }
+ strings::StrAppend(&result, dim.size);
}
- return Status::OK();
+ return result;
}
-Status BindProto(SqliteStatement* stmt, int parameter,
- const protobuf::MessageLite& proto) {
- string serialized;
- TF_RETURN_IF_ERROR(Serialize(proto, &serialized));
- stmt->BindBlob(parameter, serialized);
+Status CheckSupportedType(const Tensor& t) {
+#define CASE(T) \
+ case DataTypeToEnum<T>::value: \
+ break;
+ switch (t.dtype()) {
+ CALL_SUPPORTED_TYPES(CASE)
+ default:
+ return errors::Unimplemented(DataTypeString(t.dtype()),
+ " tensors unsupported on platform");
+ }
return Status::OK();
+#undef CASE
}
-Status BindTensor(SqliteStatement* stmt, int parameter, const Tensor& t) {
- // TODO(@jart): Make portable between little and big endian systems.
- // TODO(@jart): Use TensorChunks with minimal copying for big tensors.
- // TODO(@jart): Add field to indicate encoding.
- TensorProto p;
- t.AsProtoTensorContent(&p);
- return BindProto(stmt, parameter, p);
-}
-
-// Tries to fudge shape and dtype to something with smaller storage.
-Status CoerceScalar(const Tensor& t, Tensor* out) {
+Tensor AsScalar(const Tensor& t) {
+ Tensor t2{t.dtype(), {}};
+#define CASE(T) \
+ case DataTypeToEnum<T>::value: \
+ t2.scalar<T>()() = t.flat<T>()(0); \
+ break;
switch (t.dtype()) {
- case DT_DOUBLE:
- *out = t;
- break;
- case DT_INT64:
- *out = t;
- break;
- case DT_FLOAT:
- *out = {DT_DOUBLE, {}};
- out->scalar<double>()() = t.scalar<float>()();
- break;
- case DT_HALF:
- *out = {DT_DOUBLE, {}};
- out->scalar<double>()() = static_cast<double>(t.scalar<Eigen::half>()());
- break;
- case DT_INT32:
- *out = {DT_INT64, {}};
- out->scalar<int64>()() = t.scalar<int32>()();
- break;
- case DT_INT16:
- *out = {DT_INT64, {}};
- out->scalar<int64>()() = t.scalar<int16>()();
- break;
- case DT_INT8:
- *out = {DT_INT64, {}};
- out->scalar<int64>()() = t.scalar<int8>()();
- break;
- case DT_UINT32:
- *out = {DT_INT64, {}};
- out->scalar<int64>()() = t.scalar<uint32>()();
- break;
- case DT_UINT16:
- *out = {DT_INT64, {}};
- out->scalar<int64>()() = t.scalar<uint16>()();
- break;
- case DT_UINT8:
- *out = {DT_INT64, {}};
- out->scalar<int64>()() = t.scalar<uint8>()();
- break;
+ CALL_SUPPORTED_TYPES(CASE)
default:
- return errors::Unimplemented("Scalar summary for dtype ",
- DataTypeString(t.dtype()),
- " is not supported.");
+ t2 = {DT_FLOAT, {}};
+ t2.scalar<float>()() = NAN;
+ break;
}
- return Status::OK();
+ return t2;
+#undef CASE
+}
+
+void PatchPluginName(SummaryMetadata* metadata, const char* name) {
+ if (metadata->plugin_data().plugin_name().empty()) {
+ metadata->mutable_plugin_data()->set_plugin_name(name);
+ }
+}
+
+int GetSlots(const Tensor& t, const SummaryMetadata& metadata) {
+ if (metadata.plugin_data().plugin_name() == kScalarPluginName) {
+ return kScalarSlots;
+ } else if (metadata.plugin_data().plugin_name() == kImagePluginName) {
+ return kImageSlots;
+ } else if (metadata.plugin_data().plugin_name() == kAudioPluginName) {
+ return kAudioSlots;
+ } else if (metadata.plugin_data().plugin_name() == kHistogramPluginName) {
+ return kHistogramSlots;
+ } else if (t.dims() == 0 && t.dtype() != DT_STRING) {
+ return kScalarSlots;
+ } else {
+ return kTensorSlots;
+ }
+}
+
+Status SetDescription(Sqlite* db, int64 id, const StringPiece& markdown) {
+ const char* sql = R"sql(
+ INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?)
+ )sql";
+ SqliteStatement insert_desc;
+ TF_RETURN_IF_ERROR(db->Prepare(sql, &insert_desc));
+ insert_desc.BindInt(1, id);
+ insert_desc.BindText(2, markdown);
+ return insert_desc.StepAndReset();
}
-/// \brief Generates unique IDs randomly in the [1,2**63-2] range.
+/// \brief Generates unique IDs randomly in the [1,2**63-1] range.
///
/// This class starts off generating IDs in the [1,2**23-1] range,
/// because it's human friendly and occupies 4 bytes max on disk with
/// SQLite's zigzag varint encoding. Then, each time a collision
/// happens, the random space is increased by 8 bits.
///
-/// This class uses exponential back-off so writes will slow down as
-/// the ID space becomes exhausted.
+/// This class uses exponential back-off so writes gradually slow down
+/// as IDs become exhausted but reads are still possible.
+///
+/// This class is thread safe.
class IdAllocator {
public:
- IdAllocator(Env* env, Sqlite* db)
- : env_{env},
- inserter_{db->PrepareOrDie("INSERT INTO Ids (id) VALUES (?)")} {}
+ IdAllocator(Env* env, Sqlite* db) : env_{env}, db_{db} {
+ DCHECK(env_ != nullptr);
+ DCHECK(db_ != nullptr);
+ }
- Status CreateNewId(int64* id) {
+ Status CreateNewId(int64* id) LOCKS_EXCLUDED(mu_) {
+ mutex_lock lock(mu_);
Status s;
+ SqliteStatement stmt;
+ TF_RETURN_IF_ERROR(db_->Prepare("INSERT INTO Ids (id) VALUES (?)", &stmt));
for (int i = 0; i < kMaxIdCollisions; ++i) {
int64 tid = MakeRandomId();
- inserter_.BindInt(1, tid);
- s = inserter_.StepAndReset();
+ stmt.BindInt(1, tid);
+ s = stmt.StepAndReset();
if (s.ok()) {
*id = tid;
break;
@@ -167,34 +217,38 @@ class IdAllocator {
}
private:
- int64 MakeRandomId() {
+ int64 MakeRandomId() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
int64 id = static_cast<int64>(random::New64() & kIdTiers[tier_]);
if (id == kAbsent) ++id;
- if (id == kReserved) --id;
return id;
}
- Env* env_;
- SqliteStatement inserter_;
- int tier_ = 0;
+ mutex mu_;
+ Env* const env_;
+ Sqlite* const db_;
+ int tier_ GUARDED_BY(mu_) = 0;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(IdAllocator);
};
-class GraphSaver {
+class GraphWriter {
public:
- static Status Save(Env* env, Sqlite* db, IdAllocator* id_allocator,
- GraphDef* graph, int64* graph_id) {
- TF_RETURN_IF_ERROR(id_allocator->CreateNewId(graph_id));
- GraphSaver saver{env, db, graph, *graph_id};
+ static Status Save(Sqlite* db, SqliteTransaction* txn, IdAllocator* ids,
+ GraphDef* graph, uint64 now, int64 run_id, int64* graph_id)
+ SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) {
+ TF_RETURN_IF_ERROR(ids->CreateNewId(graph_id));
+ GraphWriter saver{db, txn, graph, now, *graph_id};
saver.MapNameToNodeId();
- TF_RETURN_IF_ERROR(saver.SaveNodeInputs());
- TF_RETURN_IF_ERROR(saver.SaveNodes());
- TF_RETURN_IF_ERROR(saver.SaveGraph());
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodeInputs(), "SaveNodeInputs");
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodes(), "SaveNodes");
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveGraph(run_id), "SaveGraph");
return Status::OK();
}
private:
- GraphSaver(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id)
- : env_(env), db_(db), graph_(graph), graph_id_(graph_id) {}
+ GraphWriter(Sqlite* db, SqliteTransaction* txn, GraphDef* graph, uint64 now,
+ int64 graph_id)
+ : db_(db), txn_(txn), graph_(graph), now_(now), graph_id_(graph_id) {}
void MapNameToNodeId() {
size_t toto = static_cast<size_t>(graph_->node_size());
@@ -209,161 +263,193 @@ class GraphSaver {
}
Status SaveNodeInputs() {
- auto insert = db_->PrepareOrDie(R"sql(
- INSERT INTO NodeInputs (graph_id, node_id, idx, input_node_id, is_control)
- VALUES (?, ?, ?, ?, ?)
- )sql");
+ const char* sql = R"sql(
+ INSERT INTO NodeInputs (
+ graph_id,
+ node_id,
+ idx,
+ input_node_id,
+ input_node_idx,
+ is_control
+ ) VALUES (?, ?, ?, ?, ?, ?)
+ )sql";
+ SqliteStatement insert;
+ TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
const NodeDef& node = graph_->node(node_id);
for (int idx = 0; idx < node.input_size(); ++idx) {
StringPiece name = node.input(idx);
- insert.BindInt(1, graph_id_);
- insert.BindInt(2, node_id);
- insert.BindInt(3, idx);
+ int64 input_node_id;
+ int64 input_node_idx = 0;
+ int64 is_control = 0;
+ size_t i = name.rfind(':');
+ if (i != StringPiece::npos) {
+ if (!strings::safe_strto64(name.substr(i + 1, name.size() - i - 1),
+ &input_node_idx)) {
+ return errors::DataLoss("Bad NodeDef.input: ", name);
+ }
+ name.remove_suffix(name.size() - i);
+ }
if (!name.empty() && name[0] == '^') {
name.remove_prefix(1);
- insert.BindInt(5, 1);
+ is_control = 1;
}
auto e = name_to_node_id_.find(name);
if (e == name_to_node_id_.end()) {
return errors::DataLoss("Could not find node: ", name);
}
- insert.BindInt(4, e->second);
+ input_node_id = e->second;
+ insert.BindInt(1, graph_id_);
+ insert.BindInt(2, node_id);
+ insert.BindInt(3, idx);
+ insert.BindInt(4, input_node_id);
+ insert.BindInt(5, input_node_idx);
+ insert.BindInt(6, is_control);
+ unflushed_bytes_ += insert.size();
TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(),
" -> ", name);
+ TF_RETURN_IF_ERROR(MaybeFlush());
}
}
return Status::OK();
}
Status SaveNodes() {
- auto insert = db_->PrepareOrDie(R"sql(
- INSERT INTO Nodes (graph_id, node_id, node_name, op, device, node_def)
- VALUES (?, ?, ?, ?, ?, snap(?))
- )sql");
+ const char* sql = R"sql(
+ INSERT INTO Nodes (
+ graph_id,
+ node_id,
+ node_name,
+ op,
+ device,
+ node_def)
+ VALUES (?, ?, ?, ?, ?, ?)
+ )sql";
+ SqliteStatement insert;
+ TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
NodeDef* node = graph_->mutable_node(node_id);
insert.BindInt(1, graph_id_);
insert.BindInt(2, node_id);
insert.BindText(3, node->name());
+ insert.BindText(4, node->op());
+ insert.BindText(5, node->device());
node->clear_name();
- if (!node->op().empty()) {
- insert.BindText(4, node->op());
- node->clear_op();
- }
- if (!node->device().empty()) {
- insert.BindText(5, node->device());
- node->clear_device();
- }
+ node->clear_op();
+ node->clear_device();
node->clear_input();
- TF_RETURN_IF_ERROR(BindProto(&insert, 6, *node));
+ string node_def;
+ if (node->SerializeToString(&node_def)) {
+ insert.BindBlobUnsafe(6, node_def);
+ }
+ unflushed_bytes_ += insert.size();
TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name());
+ TF_RETURN_IF_ERROR(MaybeFlush());
}
return Status::OK();
}
- Status SaveGraph() {
- auto insert = db_->PrepareOrDie(R"sql(
- INSERT INTO Graphs (graph_id, inserted_time, graph_def)
- VALUES (?, ?, snap(?))
- )sql");
- insert.BindInt(1, graph_id_);
- insert.BindDouble(2, GetWallTime(env_));
+ Status SaveGraph(int64 run_id) {
+ const char* sql = R"sql(
+ INSERT OR REPLACE INTO Graphs (
+ run_id,
+ graph_id,
+ inserted_time,
+ graph_def
+ ) VALUES (?, ?, ?, ?)
+ )sql";
+ SqliteStatement insert;
+ TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
+ if (run_id != kAbsent) insert.BindInt(1, run_id);
+ insert.BindInt(2, graph_id_);
+ insert.BindDouble(3, DoubleTime(now_));
graph_->clear_node();
- TF_RETURN_IF_ERROR(BindProto(&insert, 3, *graph_));
+ string graph_def;
+ if (graph_->SerializeToString(&graph_def)) {
+ insert.BindBlobUnsafe(4, graph_def);
+ }
return insert.StepAndReset();
}
- Env* env_;
- Sqlite* db_;
- GraphDef* graph_;
- int64 graph_id_;
+ Status MaybeFlush() {
+ if (unflushed_bytes_ >= kFlushBytes) {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(txn_->Commit(), "flushing ",
+ unflushed_bytes_, " bytes");
+ unflushed_bytes_ = 0;
+ }
+ return Status::OK();
+ }
+
+ Sqlite* const db_;
+ SqliteTransaction* const txn_;
+ uint64 unflushed_bytes_ = 0;
+ GraphDef* const graph_;
+ const uint64 now_;
+ const int64 graph_id_;
std::vector<string> name_copies_;
std::unordered_map<StringPiece, int64, StringPieceHasher> name_to_node_id_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GraphWriter);
};
-class RunWriter {
+/// \brief Run metadata manager.
+///
+/// This class gives us Tag IDs we can pass to SeriesWriter. In order
+/// to do that, rows are created in the Ids, Tags, Runs, Experiments,
+/// and Users tables.
+///
+/// This class is thread safe.
+class RunMetadata {
public:
- RunWriter(Env* env, Sqlite* db, const string& experiment_name,
- const string& run_name, const string& user_name)
- : env_{env},
- db_{db},
- id_allocator_{env_, db_},
+ RunMetadata(IdAllocator* ids, const string& experiment_name,
+ const string& run_name, const string& user_name)
+ : ids_{ids},
experiment_name_{experiment_name},
run_name_{run_name},
- user_name_{user_name},
- insert_tensor_{db_->PrepareOrDie(R"sql(
- INSERT OR REPLACE INTO Tensors (tag_id, step, computed_time, tensor)
- VALUES (?, ?, ?, snap(?))
- )sql")} {
- db_->Ref();
+ user_name_{user_name} {
+ DCHECK(ids_ != nullptr);
}
- ~RunWriter() {
- if (run_id_ != kAbsent) {
- auto update = db_->PrepareOrDie(R"sql(
- UPDATE Runs SET finished_time = ? WHERE run_id = ?
- )sql");
- update.BindDouble(1, GetWallTime(env_));
- update.BindInt(2, run_id_);
- Status s = update.StepAndReset();
- if (!s.ok()) {
- LOG(ERROR) << "Failed to set Runs[" << run_id_
- << "].finish_time: " << s.ToString();
- }
- }
- db_->Unref();
- }
+ const string& experiment_name() { return experiment_name_; }
+ const string& run_name() { return run_name_; }
+ const string& user_name() { return user_name_; }
- Status InsertTensor(int64 tag_id, int64 step, double computed_time,
- Tensor t) {
- insert_tensor_.BindInt(1, tag_id);
- insert_tensor_.BindInt(2, step);
- insert_tensor_.BindDouble(3, computed_time);
- if (t.shape().dims() == 0 && t.dtype() == DT_INT64) {
- insert_tensor_.BindInt(4, t.scalar<int64>()());
- } else if (t.shape().dims() == 0 && t.dtype() == DT_DOUBLE) {
- insert_tensor_.BindDouble(4, t.scalar<double>()());
- } else {
- TF_RETURN_IF_ERROR(BindTensor(&insert_tensor_, 4, t));
- }
- return insert_tensor_.StepAndReset();
+ int64 run_id() LOCKS_EXCLUDED(mu_) {
+ mutex_lock lock(mu_);
+ return run_id_;
}
- Status InsertGraph(std::unique_ptr<GraphDef> g, double computed_time) {
- TF_RETURN_IF_ERROR(InitializeRun(computed_time));
+ Status SetGraph(Sqlite* db, uint64 now, double computed_time,
+ std::unique_ptr<GraphDef> g) SQLITE_TRANSACTIONS_EXCLUDED(*db)
+ LOCKS_EXCLUDED(mu_) {
+ int64 run_id;
+ {
+ mutex_lock lock(mu_);
+ TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time));
+ run_id = run_id_;
+ }
int64 graph_id;
+ SqliteTransaction txn(*db); // only to increase performance
TF_RETURN_IF_ERROR(
- GraphSaver::Save(env_, db_, &id_allocator_, g.get(), &graph_id));
- if (run_id_ != kAbsent) {
- auto set =
- db_->PrepareOrDie("UPDATE Runs SET graph_id = ? WHERE run_id = ?");
- set.BindInt(1, graph_id);
- set.BindInt(2, run_id_);
- TF_RETURN_IF_ERROR(set.StepAndReset());
- }
- return Status::OK();
+ GraphWriter::Save(db, &txn, ids_, g.get(), now, run_id, &graph_id));
+ return txn.Commit();
}
- Status GetTagId(double computed_time, const string& tag_name,
- const SummaryMetadata& metadata, int64* tag_id) {
- TF_RETURN_IF_ERROR(InitializeRun(computed_time));
+ Status GetTagId(Sqlite* db, uint64 now, double computed_time,
+ const string& tag_name, int64* tag_id,
+ const SummaryMetadata& metadata) LOCKS_EXCLUDED(mu_) {
+ mutex_lock lock(mu_);
+ TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time));
auto e = tag_ids_.find(tag_name);
if (e != tag_ids_.end()) {
*tag_id = e->second;
return Status::OK();
}
- TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(tag_id));
+ TF_RETURN_IF_ERROR(ids_->CreateNewId(tag_id));
tag_ids_[tag_name] = *tag_id;
- if (!metadata.summary_description().empty()) {
- SqliteStatement insert_description = db_->PrepareOrDie(R"sql(
- INSERT INTO Descriptions (id, description) VALUES (?, ?)
- )sql");
- insert_description.BindInt(1, *tag_id);
- insert_description.BindText(2, metadata.summary_description());
- TF_RETURN_IF_ERROR(insert_description.StepAndReset());
- }
- SqliteStatement insert = db_->PrepareOrDie(R"sql(
+ TF_RETURN_IF_ERROR(
+ SetDescription(db, *tag_id, metadata.summary_description()));
+ const char* sql = R"sql(
INSERT INTO Tags (
run_id,
tag_id,
@@ -372,30 +458,54 @@ class RunWriter {
display_name,
plugin_name,
plugin_data
- ) VALUES (?, ?, ?, ?, ?, ?, ?)
- )sql");
- if (run_id_ != kAbsent) insert.BindInt(1, run_id_);
- insert.BindInt(2, *tag_id);
- insert.BindText(3, tag_name);
- insert.BindDouble(4, GetWallTime(env_));
- if (!metadata.display_name().empty()) {
- insert.BindText(5, metadata.display_name());
- }
- if (!metadata.plugin_data().plugin_name().empty()) {
- insert.BindText(6, metadata.plugin_data().plugin_name());
- }
- if (!metadata.plugin_data().content().empty()) {
- insert.BindBlob(7, metadata.plugin_data().content());
- }
+ ) VALUES (
+ :run_id,
+ :tag_id,
+ :tag_name,
+ :inserted_time,
+ :display_name,
+ :plugin_name,
+ :plugin_data
+ )
+ )sql";
+ SqliteStatement insert;
+ TF_RETURN_IF_ERROR(db->Prepare(sql, &insert));
+ if (run_id_ != kAbsent) insert.BindInt(":run_id", run_id_);
+ insert.BindInt(":tag_id", *tag_id);
+ insert.BindTextUnsafe(":tag_name", tag_name);
+ insert.BindDouble(":inserted_time", DoubleTime(now));
+ insert.BindTextUnsafe(":display_name", metadata.display_name());
+ insert.BindTextUnsafe(":plugin_name", metadata.plugin_data().plugin_name());
+ insert.BindBlobUnsafe(":plugin_data", metadata.plugin_data().content());
return insert.StepAndReset();
}
+ Status GetIsWatching(Sqlite* db, bool* is_watching)
+ SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) {
+ mutex_lock lock(mu_);
+ if (experiment_id_ == kAbsent) {
+ *is_watching = true;
+ return Status::OK();
+ }
+ const char* sql = R"sql(
+ SELECT is_watching FROM Experiments WHERE experiment_id = ?
+ )sql";
+ SqliteStatement stmt;
+ TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt));
+ stmt.BindInt(1, experiment_id_);
+ TF_RETURN_IF_ERROR(stmt.StepOnce());
+ *is_watching = stmt.ColumnInt(0) != 0;
+ return Status::OK();
+ }
+
private:
- Status InitializeUser() {
+ Status InitializeUser(Sqlite* db, uint64 now) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (user_id_ != kAbsent || user_name_.empty()) return Status::OK();
- SqliteStatement get = db_->PrepareOrDie(R"sql(
+ const char* get_sql = R"sql(
SELECT user_id FROM Users WHERE user_name = ?
- )sql");
+ )sql";
+ SqliteStatement get;
+ TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get));
get.BindText(1, user_name_);
bool is_done;
TF_RETURN_IF_ERROR(get.Step(&is_done));
@@ -403,22 +513,29 @@ class RunWriter {
user_id_ = get.ColumnInt(0);
return Status::OK();
}
- TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&user_id_));
- SqliteStatement insert = db_->PrepareOrDie(R"sql(
- INSERT INTO Users (user_id, user_name, inserted_time) VALUES (?, ?, ?)
- )sql");
+ TF_RETURN_IF_ERROR(ids_->CreateNewId(&user_id_));
+ const char* insert_sql = R"sql(
+ INSERT INTO Users (
+ user_id,
+ user_name,
+ inserted_time
+ ) VALUES (?, ?, ?)
+ )sql";
+ SqliteStatement insert;
+ TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
insert.BindInt(1, user_id_);
insert.BindText(2, user_name_);
- insert.BindDouble(3, GetWallTime(env_));
+ insert.BindDouble(3, DoubleTime(now));
TF_RETURN_IF_ERROR(insert.StepAndReset());
return Status::OK();
}
- Status InitializeExperiment(double computed_time) {
+ Status InitializeExperiment(Sqlite* db, uint64 now, double computed_time)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (experiment_name_.empty()) return Status::OK();
if (experiment_id_ == kAbsent) {
- TF_RETURN_IF_ERROR(InitializeUser());
- SqliteStatement get = db_->PrepareOrDie(R"sql(
+ TF_RETURN_IF_ERROR(InitializeUser(db, now));
+ const char* get_sql = R"sql(
SELECT
experiment_id,
started_time
@@ -427,7 +544,9 @@ class RunWriter {
WHERE
user_id IS ?
AND experiment_name = ?
- )sql");
+ )sql";
+ SqliteStatement get;
+ TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get));
if (user_id_ != kAbsent) get.BindInt(1, user_id_);
get.BindText(2, experiment_name_);
bool is_done;
@@ -436,30 +555,41 @@ class RunWriter {
experiment_id_ = get.ColumnInt(0);
experiment_started_time_ = get.ColumnInt(1);
} else {
- TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&experiment_id_));
+ TF_RETURN_IF_ERROR(ids_->CreateNewId(&experiment_id_));
experiment_started_time_ = computed_time;
- SqliteStatement insert = db_->PrepareOrDie(R"sql(
+ const char* insert_sql = R"sql(
INSERT INTO Experiments (
user_id,
experiment_id,
experiment_name,
inserted_time,
- started_time
- ) VALUES (?, ?, ?, ?, ?)
- )sql");
+ started_time,
+ is_watching
+ ) VALUES (?, ?, ?, ?, ?, ?)
+ )sql";
+ SqliteStatement insert;
+ TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
if (user_id_ != kAbsent) insert.BindInt(1, user_id_);
insert.BindInt(2, experiment_id_);
insert.BindText(3, experiment_name_);
- insert.BindDouble(4, GetWallTime(env_));
+ insert.BindDouble(4, DoubleTime(now));
insert.BindDouble(5, computed_time);
+ insert.BindInt(6, 0);
TF_RETURN_IF_ERROR(insert.StepAndReset());
}
}
if (computed_time < experiment_started_time_) {
experiment_started_time_ = computed_time;
- SqliteStatement update = db_->PrepareOrDie(R"sql(
- UPDATE Experiments SET started_time = ? WHERE experiment_id = ?
- )sql");
+ const char* update_sql = R"sql(
+ UPDATE
+ Experiments
+ SET
+ started_time = ?
+ WHERE
+ experiment_id = ?
+ )sql";
+ SqliteStatement update;
+ TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update));
update.BindDouble(1, computed_time);
update.BindInt(2, experiment_id_);
TF_RETURN_IF_ERROR(update.StepAndReset());
@@ -467,13 +597,14 @@ class RunWriter {
return Status::OK();
}
- Status InitializeRun(double computed_time) {
+ Status InitializeRun(Sqlite* db, uint64 now, double computed_time)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (run_name_.empty()) return Status::OK();
- TF_RETURN_IF_ERROR(InitializeExperiment(computed_time));
+ TF_RETURN_IF_ERROR(InitializeExperiment(db, now, computed_time));
if (run_id_ == kAbsent) {
- TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&run_id_));
+ TF_RETURN_IF_ERROR(ids_->CreateNewId(&run_id_));
run_started_time_ = computed_time;
- SqliteStatement insert = db_->PrepareOrDie(R"sql(
+ const char* insert_sql = R"sql(
INSERT OR REPLACE INTO Runs (
experiment_id,
run_id,
@@ -481,19 +612,28 @@ class RunWriter {
inserted_time,
started_time
) VALUES (?, ?, ?, ?, ?)
- )sql");
+ )sql";
+ SqliteStatement insert;
+ TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
if (experiment_id_ != kAbsent) insert.BindInt(1, experiment_id_);
insert.BindInt(2, run_id_);
insert.BindText(3, run_name_);
- insert.BindDouble(4, GetWallTime(env_));
+ insert.BindDouble(4, DoubleTime(now));
insert.BindDouble(5, computed_time);
TF_RETURN_IF_ERROR(insert.StepAndReset());
}
if (computed_time < run_started_time_) {
run_started_time_ = computed_time;
- SqliteStatement update = db_->PrepareOrDie(R"sql(
- UPDATE Runs SET started_time = ? WHERE run_id = ?
- )sql");
+ const char* update_sql = R"sql(
+ UPDATE
+ Runs
+ SET
+ started_time = ?
+ WHERE
+ run_id = ?
+ )sql";
+ SqliteStatement update;
+ TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update));
update.BindDouble(1, computed_time);
update.BindInt(2, run_id_);
TF_RETURN_IF_ERROR(update.StepAndReset());
@@ -501,79 +641,400 @@ class RunWriter {
return Status::OK();
}
- Env* env_;
- Sqlite* db_;
- IdAllocator id_allocator_;
+ mutex mu_;
+ IdAllocator* const ids_;
const string experiment_name_;
const string run_name_;
const string user_name_;
- int64 experiment_id_ = kAbsent;
- int64 run_id_ = kAbsent;
- int64 user_id_ = kAbsent;
- std::unordered_map<string, int64> tag_ids_;
- double experiment_started_time_ = 0.0;
- double run_started_time_ = 0.0;
- SqliteStatement insert_tensor_;
+ int64 experiment_id_ GUARDED_BY(mu_) = kAbsent;
+ int64 run_id_ GUARDED_BY(mu_) = kAbsent;
+ int64 user_id_ GUARDED_BY(mu_) = kAbsent;
+ double experiment_started_time_ GUARDED_BY(mu_) = 0.0;
+ double run_started_time_ GUARDED_BY(mu_) = 0.0;
+ std::unordered_map<string, int64> tag_ids_ GUARDED_BY(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RunMetadata);
};
+/// \brief Tensor writer for a single series, e.g. Tag.
+///
+/// This class can be used to write an infinite stream of Tensors to the
+/// database in a fixed block of contiguous disk space. This is
+/// accomplished using Algorithm R reservoir sampling.
+///
+/// The reservoir consists of a fixed number of rows, which are inserted
+/// using ZEROBLOB upon receiving the first sample, which is used to
+/// predict how big the other ones are likely to be. This is done
+/// transactionally in a way that tries to be mindful of other processes
+/// that might be trying to access the same DB.
+///
+/// Once the reservoir fills up, rows are replaced at random, and writes
+/// gradually become no-ops. This allows long training to go fast
+/// without configuration. The exception is when someone is actually
+/// looking at TensorBoard. When that happens, the "keep last" behavior
+/// is turned on and Append() will always result in a write.
+///
+/// If no one is watching training, this class still holds on to the
+/// most recent "dangling" Tensor, so if Finish() is called, the most
+/// recent training state can be written to disk.
+///
+/// The randomly selected sampling points should be consistent across
+/// multiple instances.
+///
+/// This class is thread safe.
+class SeriesWriter {
+ public:
+ SeriesWriter(int64 series, int slots, RunMetadata* meta)
+ : series_{series},
+ slots_{slots},
+ meta_{meta},
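+ // Fixed seed keeps the randomly selected sampling points
+ // consistent across SeriesWriter instances (see class comment).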
+ rng_{std::mt19937_64::default_seed} {
+ DCHECK(series_ > 0);
+ DCHECK(slots_ > 0);
+ }
+
+ Status Append(Sqlite* db, int64 step, uint64 now, double computed_time,
+ Tensor t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
+ LOCKS_EXCLUDED(mu_) {
+ mutex_lock lock(mu_);
+ if (rowids_.empty()) {
+ Status s = Reserve(db, t);
+ if (!s.ok()) {
+ rowids_.clear();
+ return s;
+ }
+ }
+ DCHECK(rowids_.size() == slots_);
+ int64 rowid;
+ size_t i = count_;
+ if (i < slots_) {
+ rowid = last_rowid_ = rowids_[i];
+ } else {
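+ // Algorithm R: draw i uniformly from [0, count_]; the new tensor
+ // survives only if it lands within the first slots_ reservoir rows.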
+ i = rng_() % (i + 1);
+ if (i < slots_) {
+ rowid = last_rowid_ = rowids_[i];
+ } else {
+ bool keep_last;
+ TF_RETURN_IF_ERROR(meta_->GetIsWatching(db, &keep_last));
+ if (!keep_last) {
+ ++count_;
+ dangling_tensor_.reset(new Tensor(std::move(t)));
+ dangling_step_ = step;
+ dangling_computed_time_ = computed_time;
+ return Status::OK();
+ }
+ rowid = last_rowid_;
+ }
+ }
+ Status s = Write(db, rowid, step, computed_time, t);
+ if (s.ok()) {
+ ++count_;
+ dangling_tensor_.reset();
+ }
+ return s;
+ }
+
+ Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
+ LOCKS_EXCLUDED(mu_) {
+ mutex_lock lock(mu_);
+ // Short runs: Delete unused pre-allocated Tensors.
+ if (count_ < rowids_.size()) {
+ SqliteTransaction txn(*db);
+ const char* sql = R"sql(
+ DELETE FROM Tensors WHERE rowid = ?
+ )sql";
+ SqliteStatement deleter;
+ TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter));
+ for (size_t i = count_; i < rowids_.size(); ++i) {
+ deleter.BindInt(1, rowids_[i]);
+ TF_RETURN_IF_ERROR(deleter.StepAndReset());
+ }
+ TF_RETURN_IF_ERROR(txn.Commit());
+ rowids_.clear();
+ }
+ // Long runs: Make last sample be the very most recent one.
+ if (dangling_tensor_) {
+ DCHECK(last_rowid_ != kAbsent);
+ TF_RETURN_IF_ERROR(Write(db, last_rowid_, dangling_step_,
+ dangling_computed_time_, *dangling_tensor_));
+ dangling_tensor_.reset();
+ }
+ return Status::OK();
+ }
+
+ private:
+ Status Write(Sqlite* db, int64 rowid, int64 step, double computed_time,
+ const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) {
+ if (t.dtype() == DT_STRING) {
+ if (t.dims() == 0) {
+ return Update(db, step, computed_time, t, t.scalar<string>()(), rowid);
+ } else {
+ SqliteTransaction txn(*db);
+ TF_RETURN_IF_ERROR(
+ Update(db, step, computed_time, t, StringPiece(), rowid));
+ TF_RETURN_IF_ERROR(UpdateNdString(db, t, rowid));
+ return txn.Commit();
+ }
+ } else {
+ return Update(db, step, computed_time, t, t.tensor_data(), rowid);
+ }
+ }
+
+ Status Update(Sqlite* db, int64 step, double computed_time, const Tensor& t,
+ const StringPiece& data, int64 rowid) {
+ // TODO(jart): How can we ensure reservoir fills on replace?
+ const char* sql = R"sql(
+ UPDATE OR REPLACE
+ Tensors
+ SET
+ step = ?,
+ computed_time = ?,
+ dtype = ?,
+ shape = ?,
+ data = ?
+ WHERE
+ rowid = ?
+ )sql";
+ SqliteStatement stmt;
+ TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt));
+ stmt.BindInt(1, step);
+ stmt.BindDouble(2, computed_time);
+ stmt.BindInt(3, t.dtype());
+ stmt.BindText(4, StringifyShape(t.shape()));
+ stmt.BindBlobUnsafe(5, data);
+ stmt.BindInt(6, rowid);
+ TF_RETURN_IF_ERROR(stmt.StepAndReset());
+ return Status::OK();
+ }
+
+ Status UpdateNdString(Sqlite* db, const Tensor& t, int64 tensor_rowid)
+ SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) {
+ DCHECK_EQ(t.dtype(), DT_STRING);
+ DCHECK_GT(t.dims(), 0);
+ const char* deleter_sql = R"sql(
+ DELETE FROM TensorStrings WHERE tensor_rowid = ?
+ )sql";
+ SqliteStatement deleter;
+ TF_RETURN_IF_ERROR(db->Prepare(deleter_sql, &deleter));
+ deleter.BindInt(1, tensor_rowid);
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(deleter.StepAndReset(), tensor_rowid);
+ const char* inserter_sql = R"sql(
+ INSERT INTO TensorStrings (
+ tensor_rowid,
+ idx,
+ data
+ ) VALUES (?, ?, ?)
+ )sql";
+ SqliteStatement inserter;
+ TF_RETURN_IF_ERROR(db->Prepare(inserter_sql, &inserter));
+ auto flat = t.flat<string>();
+ for (int64 i = 0; i < flat.size(); ++i) {
+ inserter.BindInt(1, tensor_rowid);
+ inserter.BindInt(2, i);
+ inserter.BindBlobUnsafe(3, flat(i));
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(inserter.StepAndReset(), "i=", i);
+ }
+ return Status::OK();
+ }
+
+ Status Reserve(Sqlite* db, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ SqliteTransaction txn(*db); // only for performance
+ unflushed_bytes_ = 0;
+ if (t.dtype() == DT_STRING) {
+ if (t.dims() == 0) {
+ TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.scalar<string>()().size()));
+ } else {
+ TF_RETURN_IF_ERROR(ReserveTensors(db, &txn, kReserveMinBytes));
+ }
+ } else {
+ TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.tensor_data().size()));
+ }
+ return txn.Commit();
+ }
+
+ Status ReserveData(Sqlite* db, SqliteTransaction* txn, size_t size)
+ SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 space =
+ static_cast<int64>(static_cast<double>(size) * kReserveMultiplier);
+ if (space < kReserveMinBytes) space = kReserveMinBytes;
+ return ReserveTensors(db, txn, space);
+ }
+
+ Status ReserveTensors(Sqlite* db, SqliteTransaction* txn,
+ int64 reserved_bytes)
+ SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ const char* sql = R"sql(
+ INSERT INTO Tensors (
+ series,
+ data
+ ) VALUES (?, ZEROBLOB(?))
+ )sql";
+ SqliteStatement insert;
+ TF_RETURN_IF_ERROR(db->Prepare(sql, &insert));
+ // TODO(jart): Maybe preallocate index pages by setting step. This
+ // is tricky because UPDATE OR REPLACE can have a side
+ // effect of deleting preallocated rows.
+ for (int64 i = 0; i < slots_; ++i) {
+ insert.BindInt(1, series_);
+ insert.BindInt(2, reserved_bytes);
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i);
+ rowids_.push_back(db->last_insert_rowid());
+ unflushed_bytes_ += reserved_bytes;
+ TF_RETURN_IF_ERROR(MaybeFlush(db, txn));
+ }
+ return Status::OK();
+ }
+
+ Status MaybeFlush(Sqlite* db, SqliteTransaction* txn)
+ SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (unflushed_bytes_ >= kFlushBytes) {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(txn->Commit(), "flushing ",
+ unflushed_bytes_, " bytes");
+ unflushed_bytes_ = 0;
+ }
+ return Status::OK();
+ }
+
+ mutex mu_;
+ const int64 series_;
+ const int slots_;
+ RunMetadata* const meta_;
+ std::mt19937_64 rng_ GUARDED_BY(mu_);
+ uint64 count_ GUARDED_BY(mu_) = 0;
+ int64 last_rowid_ GUARDED_BY(mu_) = kAbsent;
+ std::vector<int64> rowids_ GUARDED_BY(mu_);
+ uint64 unflushed_bytes_ GUARDED_BY(mu_) = 0;
+ std::unique_ptr<Tensor> dangling_tensor_ GUARDED_BY(mu_);
+ int64 dangling_step_ GUARDED_BY(mu_) = 0;
+ double dangling_computed_time_ GUARDED_BY(mu_) = 0.0;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter);
+};
+
+/// \brief Tensor writer for a single Run.
+///
+/// This class farms out tensors to SeriesWriter instances. It also
+/// keeps track of whether or not someone is watching the TensorBoard
+/// GUI, so it can avoid writes when possible.
+///
+/// This class is thread safe.
+class RunWriter {
+ public:
+ explicit RunWriter(RunMetadata* meta) : meta_{meta} {}
+
+ Status Append(Sqlite* db, int64 tag_id, int64 step, uint64 now,
+ double computed_time, Tensor t, int slots)
+ SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) {
+ SeriesWriter* writer = GetSeriesWriter(tag_id, slots);
+ return writer->Append(db, step, now, computed_time, std::move(t));
+ }
+
+ Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
+ LOCKS_EXCLUDED(mu_) {
+ mutex_lock lock(mu_);
+ if (series_writers_.empty()) return Status::OK();
+ for (auto i = series_writers_.begin(); i != series_writers_.end(); ++i) {
+ if (!i->second) continue;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(i->second->Finish(db),
+ "finish tag_id=", i->first);
+ i->second.reset();
+ }
+ return Status::OK();
+ }
+
+ private:
+ SeriesWriter* GetSeriesWriter(int64 tag_id, int slots) LOCKS_EXCLUDED(mu_) {
+ mutex_lock sl(mu_);
+ auto spot = series_writers_.find(tag_id);
+ if (spot == series_writers_.end()) {
+ SeriesWriter* writer = new SeriesWriter(tag_id, slots, meta_);
+ series_writers_[tag_id].reset(writer);
+ return writer;
+ } else {
+ return spot->second.get();
+ }
+ }
+
+ mutex mu_;
+ RunMetadata* const meta_;
+ std::unordered_map<int64, std::unique_ptr<SeriesWriter>> series_writers_
+ GUARDED_BY(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RunWriter);
+};
+
+/// \brief SQLite implementation of SummaryWriterInterface.
+///
+/// This class is thread safe.
class SummaryDbWriter : public SummaryWriterInterface {
public:
- SummaryDbWriter(Env* env, Sqlite* db,
- const string& experiment_name, const string& run_name,
- const string& user_name)
- : env_{env},
- run_writer_{env, db, experiment_name, run_name, user_name} {}
- ~SummaryDbWriter() override {}
+ SummaryDbWriter(Env* env, Sqlite* db, const string& experiment_name,
+ const string& run_name, const string& user_name)
+ : SummaryWriterInterface(),
+ env_{env},
+ db_{db},
+ ids_{env_, db_},
+ meta_{&ids_, experiment_name, run_name, user_name},
+ run_{&meta_} {
+ DCHECK(env_ != nullptr);
+ db_->Ref();
+ }
+
+ ~SummaryDbWriter() override {
+ core::ScopedUnref unref(db_);
+ Status s = run_.Finish(db_);
+ if (!s.ok()) {
+ // TODO(jart): Retry on transient errors here.
+ LOG(ERROR) << s.ToString();
+ }
+ int64 run_id = meta_.run_id();
+ if (run_id == kAbsent) return;
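+    // No more writes can happen at this point, so stamp the Runs row
+    // with its finished_time.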
+ const char* sql = R"sql(
+ UPDATE Runs SET finished_time = ? WHERE run_id = ?
+ )sql";
+ SqliteStatement update;
+ s = db_->Prepare(sql, &update);
+ if (s.ok()) {
+ update.BindDouble(1, DoubleTime(env_->NowMicros()));
+ update.BindInt(2, run_id);
+ s = update.StepAndReset();
+ }
+ if (!s.ok()) {
+      LOG(ERROR) << "Failed to set Runs[" << run_id
+                 << "].finished_time: " << s.ToString();
+ }
+ }
Status Flush() override { return Status::OK(); }
Status WriteTensor(int64 global_step, Tensor t, const string& tag,
const string& serialized_metadata) override {
- mutex_lock ml(mu_);
+ TF_RETURN_IF_ERROR(CheckSupportedType(t));
SummaryMetadata metadata;
- if (!serialized_metadata.empty()) {
- metadata.ParseFromString(serialized_metadata);
+ if (!metadata.ParseFromString(serialized_metadata)) {
+ return errors::InvalidArgument("Bad serialized_metadata");
}
- double now = GetWallTime(env_);
- int64 tag_id;
- TF_RETURN_IF_ERROR(run_writer_.GetTagId(now, tag, metadata, &tag_id));
- return run_writer_.InsertTensor(tag_id, global_step, now, t);
+ return Write(global_step, t, tag, metadata);
}
Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
- Tensor t2;
- TF_RETURN_IF_ERROR(CoerceScalar(t, &t2));
- // TODO(jart): Generate scalars plugin metadata on this value.
- return WriteTensor(global_step, std::move(t2), tag, "");
+ TF_RETURN_IF_ERROR(CheckSupportedType(t));
+ SummaryMetadata metadata;
+ PatchPluginName(&metadata, kScalarPluginName);
+ return Write(global_step, AsScalar(t), tag, metadata);
}
Status WriteGraph(int64 global_step, std::unique_ptr<GraphDef> g) override {
- mutex_lock ml(mu_);
- return run_writer_.InsertGraph(std::move(g), GetWallTime(env_));
+ uint64 now = env_->NowMicros();
+ return meta_.SetGraph(db_, now, DoubleTime(now), std::move(g));
}
Status WriteEvent(std::unique_ptr<Event> e) override {
- switch (e->what_case()) {
- case Event::WhatCase::kSummary: {
- mutex_lock ml(mu_);
- Status s;
- for (const auto& value : e->summary().value()) {
- s.Update(WriteSummary(e.get(), value));
- }
- return s;
- }
- case Event::WhatCase::kGraphDef: {
- mutex_lock ml(mu_);
- std::unique_ptr<GraphDef> graph{new GraphDef};
- if (!ParseProtoUnlimited(graph.get(), e->graph_def())) {
- return errors::DataLoss("parse event.graph_def failed");
- }
- return run_writer_.InsertGraph(std::move(graph), e->wall_time());
- }
- default:
- // TODO(@jart): Handle other stuff.
- return Status::OK();
- }
+ return MigrateEvent(std::move(e));
}
Status WriteHistogram(int64 global_step, Tensor t,
@@ -600,26 +1061,165 @@ class SummaryDbWriter : public SummaryWriterInterface {
string DebugString() override { return "SummaryDbWriter"; }
private:
- Status WriteSummary(const Event* e, const Summary::Value& summary)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- switch (summary.value_case()) {
- case Summary::Value::ValueCase::kSimpleValue: {
- int64 tag_id;
- TF_RETURN_IF_ERROR(run_writer_.GetTagId(e->wall_time(), summary.tag(),
- summary.metadata(), &tag_id));
- Tensor t{DT_DOUBLE, {}};
- t.scalar<double>()() = summary.simple_value();
- return run_writer_.InsertTensor(tag_id, e->step(), e->wall_time(), t);
+ Status Write(int64 step, const Tensor& t, const string& tag,
+ const SummaryMetadata& metadata) {
+ uint64 now = env_->NowMicros();
+ double computed_time = DoubleTime(now);
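+    // DoubleTime() converts micros to the fractional seconds stored in
+    // the computed_time columns.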
+ int64 tag_id;
+ TF_RETURN_IF_ERROR(
+ meta_.GetTagId(db_, now, computed_time, tag, &tag_id, metadata));
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ run_.Append(db_, tag_id, step, now, computed_time, t,
+ GetSlots(t, metadata)),
+ meta_.user_name(), "/", meta_.experiment_name(), "/", meta_.run_name(),
+ "/", tag, "@", step);
+ return Status::OK();
+ }
+
+ Status MigrateEvent(std::unique_ptr<Event> e) {
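+    // Converts legacy Event protos into rows; event types without a
+    // migration path are silently ignored for now.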
+ switch (e->what_case()) {
+ case Event::WhatCase::kSummary: {
+ uint64 now = env_->NowMicros();
+ auto summaries = e->mutable_summary();
+ for (int i = 0; i < summaries->value_size(); ++i) {
+ Summary::Value* value = summaries->mutable_value(i);
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ MigrateSummary(e.get(), value, now), meta_.user_name(), "/",
+ meta_.experiment_name(), "/", meta_.run_name(), "/", value->tag(),
+ "@", e->step());
+ }
+ break;
}
+ case Event::WhatCase::kGraphDef:
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ MigrateGraph(e.get(), e->graph_def()), meta_.user_name(), "/",
+ meta_.experiment_name(), "/", meta_.run_name(), "/__graph__@",
+ e->step());
+ break;
default:
- // TODO(@jart): Handle the rest.
- return Status::OK();
+ // TODO(@jart): Handle other stuff.
+ break;
}
+ return Status::OK();
}
- mutex mu_;
- Env* env_;
- RunWriter run_writer_ GUARDED_BY(mu_);
+ Status MigrateGraph(const Event* e, const string& graph_def) {
+ uint64 now = env_->NowMicros();
+ std::unique_ptr<GraphDef> graph{new GraphDef};
+ if (!ParseProtoUnlimited(graph.get(), graph_def)) {
+ return errors::InvalidArgument("bad proto");
+ }
+ return meta_.SetGraph(db_, now, e->wall_time(), std::move(graph));
+ }
+
+ Status MigrateSummary(const Event* e, Summary::Value* s, uint64 now) {
+ switch (s->value_case()) {
+ case Summary::Value::ValueCase::kTensor:
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateTensor(e, s, now), "tensor");
+ break;
+ case Summary::Value::ValueCase::kSimpleValue:
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateScalar(e, s, now), "scalar");
+ break;
+ case Summary::Value::ValueCase::kHisto:
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateHistogram(e, s, now), "histo");
+ break;
+ case Summary::Value::ValueCase::kImage:
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateImage(e, s, now), "image");
+ break;
+ case Summary::Value::ValueCase::kAudio:
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateAudio(e, s, now), "audio");
+ break;
+ default:
+ break;
+ }
+ return Status::OK();
+ }
+
+ Status MigrateTensor(const Event* e, Summary::Value* s, uint64 now) {
+ Tensor t;
+    if (!t.FromProto(s->tensor())) {
+      return errors::InvalidArgument("could not parse TensorProto");
+    }
+ TF_RETURN_IF_ERROR(CheckSupportedType(t));
+ int64 tag_id;
+ TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
+ &tag_id, s->metadata()));
+ return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t,
+ GetSlots(t, s->metadata()));
+ }
+
+ // TODO(jart): Refactor Summary -> Tensor logic into separate file.
+
+ Status MigrateScalar(const Event* e, Summary::Value* s, uint64 now) {
+ // See tensorboard/plugins/scalar/summary.py and data_compat.py
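+    // Summary::simple_value is a float, so DT_FLOAT round-trips it exactly.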
+ Tensor t{DT_FLOAT, {}};
+ t.scalar<float>()() = s->simple_value();
+ int64 tag_id;
+ PatchPluginName(s->mutable_metadata(), kScalarPluginName);
+ TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
+ &tag_id, s->metadata()));
+ return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
+ std::move(t), kScalarSlots);
+ }
+
+ Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) {
+ const HistogramProto& histo = s->histo();
+ int k = histo.bucket_size();
+ if (k != histo.bucket_limit_size()) {
+ return errors::InvalidArgument("size mismatch");
+ }
+ // See tensorboard/plugins/histogram/summary.py and data_compat.py
+ Tensor t{DT_DOUBLE, {k, 3}};
+ auto data = t.flat<double>();
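+    // Each of the k rows is (left_edge, right_edge, bucket_count).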
+    for (int i = 0; i < k; ++i) {
+      double left_edge = ((i > 0) ? histo.bucket_limit(i - 1)
+                                  : std::numeric_limits<double>::min());
+      data(i * 3 + 0) = left_edge;
+      data(i * 3 + 1) = histo.bucket_limit(i);
+      data(i * 3 + 2) = histo.bucket(i);
+ }
+ int64 tag_id;
+ PatchPluginName(s->mutable_metadata(), kHistogramPluginName);
+ TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
+ &tag_id, s->metadata()));
+ return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
+ std::move(t), kHistogramSlots);
+ }
+
+ Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) {
+ // See tensorboard/plugins/image/summary.py and data_compat.py
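+    // Row layout is [width, height, encoded_image_string], all as strings.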
+ Tensor t{DT_STRING, {3}};
+ auto img = s->mutable_image();
+ t.flat<string>()(0) = strings::StrCat(img->width());
+ t.flat<string>()(1) = strings::StrCat(img->height());
+ t.flat<string>()(2) = std::move(*img->mutable_encoded_image_string());
+ int64 tag_id;
+ PatchPluginName(s->mutable_metadata(), kImagePluginName);
+ TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
+ &tag_id, s->metadata()));
+ return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
+ std::move(t), kImageSlots);
+ }
+
+ Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) {
+ // See tensorboard/plugins/audio/summary.py and data_compat.py
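+    // Each row is [encoded_audio_string, label]; legacy audio summaries
+    // carry no label, hence the empty string.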
+ Tensor t{DT_STRING, {1, 2}};
+ auto wav = s->mutable_audio();
+ t.flat<string>()(0) = std::move(*wav->mutable_encoded_audio_string());
+ t.flat<string>()(1) = "";
+ int64 tag_id;
+ PatchPluginName(s->mutable_metadata(), kAudioPluginName);
+ TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
+ &tag_id, s->metadata()));
+ return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
+ std::move(t), kAudioSlots);
+ }
+
+ Env* const env_;
+ Sqlite* const db_;
+ IdAllocator ids_;
+ RunMetadata meta_;
+ RunWriter run_;
};
} // namespace
@@ -627,8 +1227,6 @@ class SummaryDbWriter : public SummaryWriterInterface {
Status CreateSummaryDbWriter(Sqlite* db, const string& experiment_name,
const string& run_name, const string& user_name,
Env* env, SummaryWriterInterface** result) {
- *result = nullptr;
- TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db));
*result = new SummaryDbWriter(env, db, experiment_name, run_name, user_name);
return Status::OK();
}
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.h b/tensorflow/contrib/tensorboard/db/summary_db_writer.h
index 5a3de195de..746da1533b 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer.h
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.h
@@ -19,16 +19,15 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/db/sqlite.h"
#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/types.h"
namespace tensorflow {
/// \brief Creates SQLite SummaryWriterInterface.
///
/// This can be used to write tensors from the execution graph directly
-/// to a database. The schema will be created automatically, but only
-/// if necessary. Entries in the Users, Experiments, and Runs tables
-/// will be created automatically if they don't already exist.
+/// to a database. The schema must be created beforehand. Entries in
+/// Users, Experiments, and Runs tables will be created automatically
+/// if they don't already exist.
///
/// Please note that the type signature of this function may change in
/// the future if support for other DBs is added to core.
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
index 68444c35be..29b8063218 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
+#include "tensorflow/contrib/tensorboard/db/schema.h"
+#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
@@ -27,8 +29,6 @@ limitations under the License.
namespace tensorflow {
namespace {
-const float kTolerance = 1e-5;
-
Tensor MakeScalarInt64(int64 x) {
Tensor t(DT_INT64, TensorShape({}));
t.scalar<int64>()() = x;
@@ -50,6 +50,7 @@ class SummaryDbWriterTest : public ::testing::Test {
protected:
void SetUp() override {
TF_ASSERT_OK(Sqlite::Open(":memory:", SQLITE_OPEN_READWRITE, &db_));
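+    // CreateSummaryDbWriter no longer sets up the schema, so do it here.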
+ TF_ASSERT_OK(SetupTensorboardSqliteDb(db_));
}
void TearDown() override {
@@ -138,7 +139,7 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) {
ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs"));
ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags"));
- ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
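+  // The writer now preallocates the full reservoir for each series, so a
+  // single tag yields 10000 Tensors rows.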
+ ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
int64 user_id = QueryInt("SELECT user_id FROM Users");
int64 experiment_id = QueryInt("SELECT experiment_id FROM Experiments");
@@ -170,17 +171,13 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) {
EXPECT_EQ("plugin_data", QueryString("SELECT plugin_data FROM Tags"));
EXPECT_EQ("description", QueryString("SELECT description FROM Descriptions"));
- EXPECT_EQ(tag_id, QueryInt("SELECT tag_id FROM Tensors WHERE step = 1"));
+ EXPECT_EQ(tag_id, QueryInt("SELECT series FROM Tensors WHERE step = 1"));
EXPECT_EQ(0.023,
QueryDouble("SELECT computed_time FROM Tensors WHERE step = 1"));
- EXPECT_FALSE(
- QueryString("SELECT tensor FROM Tensors WHERE step = 1").empty());
- EXPECT_EQ(tag_id, QueryInt("SELECT tag_id FROM Tensors WHERE step = 2"));
+ EXPECT_EQ(tag_id, QueryInt("SELECT series FROM Tensors WHERE step = 2"));
EXPECT_EQ(0.046,
QueryDouble("SELECT computed_time FROM Tensors WHERE step = 2"));
- EXPECT_FALSE(
- QueryString("SELECT tensor FROM Tensors WHERE step = 2").empty());
}
TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) {
@@ -191,7 +188,7 @@ TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) {
ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs"));
ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags"));
- ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+ ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
}
TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) {
@@ -208,33 +205,24 @@ TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) {
TF_ASSERT_OK(writer_->WriteEvent(std::move(e)));
TF_ASSERT_OK(writer_->Flush());
ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags"));
- ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+ ASSERT_EQ(20000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'π'");
int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'φ'");
EXPECT_GT(tag1_id, 0LL);
EXPECT_GT(tag2_id, 0LL);
EXPECT_EQ(123.456, QueryDouble(strings::StrCat(
- "SELECT computed_time FROM Tensors WHERE tag_id = ",
+ "SELECT computed_time FROM Tensors WHERE series = ",
tag1_id, " AND step = 7")));
EXPECT_EQ(123.456, QueryDouble(strings::StrCat(
- "SELECT computed_time FROM Tensors WHERE tag_id = ",
+ "SELECT computed_time FROM Tensors WHERE series = ",
tag2_id, " AND step = 7")));
- EXPECT_NEAR(3.14,
- QueryDouble(strings::StrCat(
- "SELECT tensor FROM Tensors WHERE tag_id = ", tag1_id,
- " AND step = 7")),
- kTolerance); // Summary::simple_value is float
- EXPECT_NEAR(1.61,
- QueryDouble(strings::StrCat(
- "SELECT tensor FROM Tensors WHERE tag_id = ", tag2_id,
- " AND step = 7")),
- kTolerance);
}
TEST_F(SummaryDbWriterTest, WriteGraph) {
TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "R", "", &env_, &writer_));
env_.AdvanceByMillis(23);
GraphDef graph;
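+  // Node defs are stripped from the stored graph_def, but the rest of the
+  // proto (e.g. the function library) should survive; see the assertions
+  // on graph2 below.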
+ graph.mutable_library()->add_gradient()->set_function_name("funk");
NodeDef* node = graph.add_node();
node->set_name("x");
node->set_op("Placeholder");
@@ -260,11 +248,17 @@ TEST_F(SummaryDbWriterTest, WriteGraph) {
ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Nodes"));
ASSERT_EQ(3LL, QueryInt("SELECT COUNT(*) FROM NodeInputs"));
+ ASSERT_EQ(QueryInt("SELECT run_id FROM Runs"),
+ QueryInt("SELECT run_id FROM Graphs"));
+
int64 graph_id = QueryInt("SELECT graph_id FROM Graphs");
EXPECT_GT(graph_id, 0LL);
- EXPECT_EQ(graph_id, QueryInt("SELECT graph_id FROM Runs"));
EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Graphs"));
- EXPECT_FALSE(QueryString("SELECT graph_def FROM Graphs").empty());
+
+ GraphDef graph2;
+  ASSERT_TRUE(
+      graph2.ParseFromString(QueryString("SELECT graph_def FROM Graphs")));
+ EXPECT_EQ(0, graph2.node_size());
+ EXPECT_EQ("funk", graph2.library().gradient(0).function_name());
EXPECT_EQ("x", QueryString("SELECT node_name FROM Nodes WHERE node_id = 0"));
EXPECT_EQ("y", QueryString("SELECT node_name FROM Nodes WHERE node_id = 1"));
@@ -307,33 +301,6 @@ TEST_F(SummaryDbWriterTest, WriteGraph) {
EXPECT_EQ(1LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 2"));
}
-TEST_F(SummaryDbWriterTest, WriteScalarInt32_CoercesToInt64) {
- TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
- Tensor t(DT_INT32, {});
- t.scalar<int32>()() = -17;
- TF_ASSERT_OK(writer_->WriteScalar(1, t, "t"));
- TF_ASSERT_OK(writer_->Flush());
- ASSERT_EQ(-17LL, QueryInt("SELECT tensor FROM Tensors"));
-}
-
-TEST_F(SummaryDbWriterTest, WriteScalarInt8_CoercesToInt64) {
- TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
- Tensor t(DT_INT8, {});
- t.scalar<int8>()() = static_cast<int8>(-17);
- TF_ASSERT_OK(writer_->WriteScalar(1, t, "t"));
- TF_ASSERT_OK(writer_->Flush());
- ASSERT_EQ(-17LL, QueryInt("SELECT tensor FROM Tensors"));
-}
-
-TEST_F(SummaryDbWriterTest, WriteScalarUint8_CoercesToInt64) {
- TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
- Tensor t(DT_UINT8, {});
- t.scalar<uint8>()() = static_cast<uint8>(254);
- TF_ASSERT_OK(writer_->WriteScalar(1, t, "t"));
- TF_ASSERT_OK(writer_->Flush());
- ASSERT_EQ(254LL, QueryInt("SELECT tensor FROM Tensors"));
-}
-
TEST_F(SummaryDbWriterTest, UsesIdsTable) {
SummaryMetadata metadata;
TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_,