author Justine Tunney <jart@google.com> 2017-11-15 11:31:43 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-11-15 11:35:49 -0800
commit 6fb721d608c4cd3855fe8793099a629428b9853c (patch)
tree faef08ed8bac4f5a8b065825a4405ef8a12e875f /tensorflow/contrib/tensorboard
parent b7b183b90aee8a4f4808f7d90a2c7a54a942e640 (diff)
Add graph writer op to contrib/summary
This change also defines a simple SQL data model for tf.GraphDef, which should move us closer to a world where TensorBoard can render the graph explorer without having to download the entire thing to the browser, as that could potentially be hundreds of megabytes.

PiperOrigin-RevId: 175854921
Diffstat (limited to 'tensorflow/contrib/tensorboard')
-rw-r--r-- tensorflow/contrib/tensorboard/db/schema.cc                 | 141
-rw-r--r-- tensorflow/contrib/tensorboard/db/summary_db_writer.cc      | 272
-rw-r--r-- tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc |  78
3 files changed, 404 insertions, 87 deletions
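To make the data model concrete, here is a hypothetical pair of queries that a reader such as the TensorBoard graph explorer could run against the Graphs, Nodes, and NodeInputs tables defined in schema.cc below, fetching a single node and its ordered fan-in without touching the serialized GraphDef blob. The table and column names come from this patch; the :graph_id, :node_name, and :node_id parameters are placeholders.

    -- Look up one node by name (served by NodeNameIndex).
    SELECT node_id, op, device
      FROM Nodes
     WHERE graph_id = :graph_id AND node_name = :node_name;

    -- List that node's inputs in order (served by NodeInputsIndex).
    SELECT i.idx, src.node_name, i.is_control
      FROM NodeInputs AS i
      JOIN Nodes AS src
        ON src.graph_id = i.graph_id AND src.node_id = i.input_node_id
     WHERE i.graph_id = :graph_id AND i.node_id = :node_id
     ORDER BY i.idx;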
diff --git a/tensorflow/contrib/tensorboard/db/schema.cc b/tensorflow/contrib/tensorboard/db/schema.cc
index 98fff9e0ae..d63b2c6cc2 100644
--- a/tensorflow/contrib/tensorboard/db/schema.cc
+++ b/tensorflow/contrib/tensorboard/db/schema.cc
@@ -135,8 +135,7 @@ class SqliteSchema {
/// the database. This field will be mutated if the run is
/// restarted.
/// description: Optional markdown information.
- /// graph: Snappy tf.GraphDef proto with node field cleared. That
- /// field can be recreated using GraphNodes and NodeDefs.
+ /// graph_id: ID of associated Graphs row.
Status CreateRunsTable() {
return Run(R"sql(
CREATE TABLE IF NOT EXISTS Runs (
@@ -147,7 +146,7 @@ class SqliteSchema {
inserted_time REAL,
started_time REAL,
description TEXT,
- graph BLOB
+ graph_id INTEGER
)
)sql");
}
@@ -205,46 +204,78 @@ class SqliteSchema {
)sql");
}
- /// \brief Creates NodeDefs table.
- ///
- /// This table stores NodeDef protos which define the GraphDef for a
- /// Run. This functions like a hash table so rows can be shared by
- /// multiple Runs in an Experiment.
+ /// \brief Creates Graphs table.
///
/// Fields:
/// rowid: Ephemeral b-tree ID dictating locality.
- /// experiment_id: Optional int64 for grouping rows.
- /// node_def_id: Permanent >0 unique ID.
- /// fingerprint: Optional farmhash::Fingerprint64() of uncompressed
- /// node_def bytes, coerced to int64.
- /// node_def: BLOB containing a Snappy tf.NodeDef proto.
- Status CreateNodeDefsTable() {
+ /// graph_id: Permanent >0 unique ID.
+ /// inserted_time: Float UNIX timestamp with µs precision. This is
+ /// always the wall time of when the row was inserted into the
+ /// DB. It may be used as a hint for an archival job.
+ /// graph_def: Contains Snappy tf.GraphDef proto, with the node field
+ /// cleared; nodes are stored in the Nodes table instead.
+ Status CreateGraphsTable() {
return Run(R"sql(
- CREATE TABLE IF NOT EXISTS NodeDefs (
+ CREATE TABLE IF NOT EXISTS Graphs (
rowid INTEGER PRIMARY KEY,
- experiment_id INTEGER,
- node_def_id INTEGER NOT NULL,
- fingerprint INTEGER,
- node_def TEXT
+ graph_id INTEGER NOT NULL,
+ inserted_time REAL,
+ graph_def BLOB
)
)sql");
}
- /// \brief Creates RunNodeDefs table.
+ /// \brief Creates Nodes table.
///
- /// Table mapping Runs to NodeDefs. This is used to recreate the node
- /// field of the GraphDef proto.
+ /// Fields:
+ /// rowid: Ephemeral b-tree ID dictating locality.
+ /// graph_id: ID of the associated Graphs row.
+ /// node_id: ID for this node, which is effectively a 0-based index
+ /// within the Graph. Note that individual indexes may be removed.
+ /// node_name: Unique name for this Node within Graph. This is
+ /// copied from the proto so it can be indexed. This is allowed
+ /// to be NULL to save space on the index, in which case the
+ /// node_def.name proto field must not be cleared.
+ /// op: Copied from tf.NodeDef proto.
+ /// device: Copied from tf.NodeDef proto.
+ /// node_def: Contains Snappy tf.NodeDef proto. Fields stored in other
+ /// columns (name, op, device, input) are cleared before serialization.
+ Status CreateNodesTable() {
+ return Run(R"sql(
+ CREATE TABLE IF NOT EXISTS Nodes (
+ rowid INTEGER PRIMARY KEY,
+ graph_id INTEGER NOT NULL,
+ node_id INTEGER NOT NULL,
+ node_name TEXT,
+ op TEXT,
+ device TEXT,
+ node_def BLOB
+ )
+ )sql");
+ }
+
+ /// \brief Creates NodeInputs table.
///
/// Fields:
/// rowid: Ephemeral b-tree ID dictating locality.
- /// run_id: Mandatory ID of associated Run.
- /// node_def_id: Mandatory ID of associated NodeDef.
- Status CreateRunNodeDefsTable() {
+ /// graph_id: ID of the associated Graphs row.
+ /// node_id: Index of Node in question. This can be considered the
+ /// 'to' vertex.
+ /// idx: Used for ordering inputs on a given Node.
+ /// input_node_id: Nodes.node_id of the corresponding input node.
+ /// This can be considered the 'from' vertex.
+ /// is_control: If non-zero, indicates this input is a control
+ /// dependency, which means this isn't an edge through which
+ /// tensors flow. NULL means 0.
+ Status CreateNodeInputsTable() {
return Run(R"sql(
- CREATE TABLE IF NOT EXISTS RunNodeDefs (
+ CREATE TABLE IF NOT EXISTS NodeInputs (
rowid INTEGER PRIMARY KEY,
- run_id INTEGER NOT NULL,
- node_def_id INTEGER NOT NULL
+ graph_id INTEGER NOT NULL,
+ node_id INTEGER NOT NULL,
+ idx INTEGER NOT NULL,
+ input_node_id INTEGER NOT NULL,
+ is_control INTEGER
)
)sql");
}
@@ -297,11 +328,27 @@ class SqliteSchema {
)sql");
}
- /// \brief Uniquely indexes node_def_id on NodeDefs table.
- Status CreateNodeDefIdIndex() {
+ /// \brief Uniquely indexes graph_id on Graphs table.
+ Status CreateGraphIdIndex() {
return Run(R"sql(
- CREATE UNIQUE INDEX IF NOT EXISTS NodeDefIdIndex
- ON NodeDefs (node_def_id)
+ CREATE UNIQUE INDEX IF NOT EXISTS GraphIdIndex
+ ON Graphs (graph_id)
+ )sql");
+ }
+
+ /// \brief Uniquely indexes (graph_id, node_id) on Nodes table.
+ Status CreateNodeIdIndex() {
+ return Run(R"sql(
+ CREATE UNIQUE INDEX IF NOT EXISTS NodeIdIndex
+ ON Nodes (graph_id, node_id)
+ )sql");
+ }
+
+ /// \brief Uniquely indexes (graph_id, node_id, idx) on NodeInputs table.
+ Status CreateNodeInputsIndex() {
+ return Run(R"sql(
+ CREATE UNIQUE INDEX IF NOT EXISTS NodeInputsIndex
+ ON NodeInputs (graph_id, node_id, idx)
)sql");
}
@@ -350,20 +397,12 @@ class SqliteSchema {
)sql");
}
- /// \brief Indexes (experiment_id, fingerprint) on NodeDefs table.
- Status CreateNodeDefFingerprintIndex() {
- return Run(R"sql(
- CREATE INDEX IF NOT EXISTS NodeDefFingerprintIndex
- ON NodeDefs (experiment_id, fingerprint)
- WHERE fingerprint IS NOT NULL
- )sql");
- }
-
- /// \brief Uniquely indexes (run_id, node_def_id) on RunNodeDefs table.
- Status CreateRunNodeDefIndex() {
+ /// \brief Uniquely indexes (graph_id, node_name) on Nodes table.
+ Status CreateNodeNameIndex() {
return Run(R"sql(
- CREATE UNIQUE INDEX IF NOT EXISTS RunNodeDefIndex
- ON RunNodeDefs (run_id, node_def_id)
+ CREATE UNIQUE INDEX IF NOT EXISTS NodeNameIndex
+ ON Nodes (graph_id, node_name)
+ WHERE node_name IS NOT NULL
)sql");
}
@@ -387,22 +426,24 @@ Status SetupTensorboardSqliteDb(std::shared_ptr<Sqlite> db) {
TF_RETURN_IF_ERROR(s.CreateRunsTable());
TF_RETURN_IF_ERROR(s.CreateExperimentsTable());
TF_RETURN_IF_ERROR(s.CreateUsersTable());
- TF_RETURN_IF_ERROR(s.CreateNodeDefsTable());
- TF_RETURN_IF_ERROR(s.CreateRunNodeDefsTable());
+ TF_RETURN_IF_ERROR(s.CreateGraphsTable());
+ TF_RETURN_IF_ERROR(s.CreateNodeInputsTable());
+ TF_RETURN_IF_ERROR(s.CreateNodesTable());
TF_RETURN_IF_ERROR(s.CreateTensorIndex());
TF_RETURN_IF_ERROR(s.CreateTensorChunkIndex());
TF_RETURN_IF_ERROR(s.CreateTagIdIndex());
TF_RETURN_IF_ERROR(s.CreateRunIdIndex());
TF_RETURN_IF_ERROR(s.CreateExperimentIdIndex());
TF_RETURN_IF_ERROR(s.CreateUserIdIndex());
- TF_RETURN_IF_ERROR(s.CreateNodeDefIdIndex());
+ TF_RETURN_IF_ERROR(s.CreateGraphIdIndex());
+ TF_RETURN_IF_ERROR(s.CreateNodeIdIndex());
+ TF_RETURN_IF_ERROR(s.CreateNodeInputsIndex());
TF_RETURN_IF_ERROR(s.CreateTagNameIndex());
TF_RETURN_IF_ERROR(s.CreateRunNameIndex());
TF_RETURN_IF_ERROR(s.CreateExperimentNameIndex());
TF_RETURN_IF_ERROR(s.CreateUserNameIndex());
TF_RETURN_IF_ERROR(s.CreateUserEmailIndex());
- TF_RETURN_IF_ERROR(s.CreateNodeDefFingerprintIndex());
- TF_RETURN_IF_ERROR(s.CreateRunNodeDefIndex());
+ TF_RETURN_IF_ERROR(s.CreateNodeNameIndex());
return Status::OK();
}
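As a rough sketch of how a graph explorer backend might use the new indexes, the hypothetical helper below resolves a node ID from its name using only the statement API already used in this patch (Prepare, BindInt, BindText, Step, ColumnInt). The function name LookupNodeId and the NotFound error are illustrative, not part of this change.

    // Hypothetical helper: find Nodes.node_id for a (graph_id, node_name) pair.
    // The lookup is served by NodeNameIndex, so it avoids scanning the graph.
    Status LookupNodeId(Sqlite* db, int64 graph_id, const string& node_name,
                        int64* node_id) {
      auto get = db->Prepare(R"sql(
        SELECT node_id FROM Nodes WHERE graph_id = ? AND node_name = ?
      )sql");
      get.BindInt(1, graph_id);
      get.BindText(2, node_name);
      bool is_done;
      TF_RETURN_IF_ERROR(get.Step(&is_done));
      if (is_done) {
        return errors::NotFound("No node named ", node_name, " in graph ",
                                graph_id);
      }
      *node_id = get.ColumnInt(0);
      return Status::OK();
    }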
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
index a26ad61660..ae063d24ef 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
@@ -15,17 +15,29 @@ limitations under the License.
#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
#include "tensorflow/contrib/tensorboard/db/schema.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/db/sqlite.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/snappy.h"
#include "tensorflow/core/util/event.pb.h"
namespace tensorflow {
namespace {
+double GetWallTime(Env* env) {
+ // TODO(@jart): Follow precise definitions for time laid out in schema.
+ // TODO(@jart): Use monotonic clock from gRPC codebase.
+ return static_cast<double>(env->NowMicros()) / 1.0e6;
+}
+
int64 MakeRandomId() {
+ // TODO(@jart): Try generating ID in 2^24 space, falling back to 2^63
+ // https://sqlite.org/src4/doc/trunk/www/varint.wiki
int64 id = static_cast<int64>(random::New64() & ((1ULL << 63) - 1));
if (id == 0) {
++id;
@@ -33,10 +45,201 @@ int64 MakeRandomId() {
return id;
}
+Status Serialize(const protobuf::MessageLite& proto, string* output) {
+ output->clear();
+ if (!proto.SerializeToString(output)) {
+ return errors::DataLoss("SerializeToString failed");
+ }
+ return Status::OK();
+}
+
+Status Compress(const string& data, string* output) {
+ output->clear();
+ if (!port::Snappy_Compress(data.data(), data.size(), output)) {
+ return errors::FailedPrecondition("TensorBase needs Snappy");
+ }
+ return Status::OK();
+}
+
+Status BindProto(SqliteStatement* stmt, int parameter,
+ const protobuf::MessageLite& proto) {
+ string serialized;
+ TF_RETURN_IF_ERROR(Serialize(proto, &serialized));
+ string compressed;
+ TF_RETURN_IF_ERROR(Compress(serialized, &compressed));
+ stmt->BindBlobUnsafe(parameter, compressed);
+ return Status::OK();
+}
+
+Status BindTensor(SqliteStatement* stmt, int parameter, const Tensor& t) {
+ // TODO(@jart): Make portable between little and big endian systems.
+ // TODO(@jart): Use TensorChunks with minimal copying for big tensors.
+ // TODO(@jart): Add field to indicate encoding.
+ // TODO(@jart): Allow crunch tool to re-compress with zlib instead.
+ TensorProto p;
+ t.AsProtoTensorContent(&p);
+ return BindProto(stmt, parameter, p);
+}
+
+class Transactor {
+ public:
+ explicit Transactor(std::shared_ptr<Sqlite> db)
+ : db_(std::move(db)),
+ begin_(db_->Prepare("BEGIN TRANSACTION")),
+ commit_(db_->Prepare("COMMIT TRANSACTION")),
+ rollback_(db_->Prepare("ROLLBACK TRANSACTION")) {}
+
+ template <typename T, typename... Args>
+ Status Transact(T callback, Args&&... args) {
+ TF_RETURN_IF_ERROR(begin_.StepAndReset());
+ Status s = callback(std::forward<Args>(args)...);
+ if (s.ok()) {
+ TF_RETURN_IF_ERROR(commit_.StepAndReset());
+ } else {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(rollback_.StepAndReset(), s.ToString());
+ }
+ return s;
+ }
+
+ private:
+ std::shared_ptr<Sqlite> db_;
+ SqliteStatement begin_;
+ SqliteStatement commit_;
+ SqliteStatement rollback_;
+};
+
+class GraphSaver {
+ public:
+ static Status SaveToRun(Env* env, Sqlite* db, GraphDef* graph, int64 run_id) {
+ auto get = db->Prepare("SELECT graph_id FROM Runs WHERE run_id = ?");
+ get.BindInt(1, run_id);
+ bool is_done;
+ TF_RETURN_IF_ERROR(get.Step(&is_done));
+ int64 graph_id = is_done ? 0 : get.ColumnInt(0);
+ if (graph_id == 0) {
+ graph_id = MakeRandomId();
+ // TODO(@jart): Check for ID collision.
+ auto set = db->Prepare("UPDATE Runs SET graph_id = ? WHERE run_id = ?");
+ set.BindInt(1, graph_id);
+ set.BindInt(2, run_id);
+ TF_RETURN_IF_ERROR(set.StepAndReset());
+ }
+ return Save(env, db, graph, graph_id);
+ }
+
+ static Status Save(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id) {
+ GraphSaver saver{env, db, graph, graph_id};
+ saver.MapNameToNodeId();
+ TF_RETURN_IF_ERROR(saver.SaveNodeInputs());
+ TF_RETURN_IF_ERROR(saver.SaveNodes());
+ TF_RETURN_IF_ERROR(saver.SaveGraph());
+ return Status::OK();
+ }
+
+ private:
+ GraphSaver(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id)
+ : env_(env), db_(db), graph_(graph), graph_id_(graph_id) {}
+
+ void MapNameToNodeId() {
+ size_t num_nodes = static_cast<size_t>(graph_->node_size());
+ name_copies_.reserve(num_nodes);
+ name_to_node_id_.reserve(num_nodes);
+ for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
+ // Copy name into memory region, since we call clear_name() later.
+ // Then wrap in StringPiece so we can compare slices without copy.
+ name_copies_.emplace_back(graph_->node(node_id).name());
+ name_to_node_id_.emplace(name_copies_.back(), node_id);
+ }
+ }
+
+ Status SaveNodeInputs() {
+ auto purge = db_->Prepare("DELETE FROM NodeInputs WHERE graph_id = ?");
+ purge.BindInt(1, graph_id_);
+ TF_RETURN_IF_ERROR(purge.StepAndReset());
+ auto insert = db_->Prepare(R"sql(
+ INSERT INTO NodeInputs (graph_id, node_id, idx, input_node_id, is_control)
+ VALUES (?, ?, ?, ?, ?)
+ )sql");
+ for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
+ const NodeDef& node = graph_->node(node_id);
+ for (int idx = 0; idx < node.input_size(); ++idx) {
+ StringPiece name = node.input(idx);
+ insert.BindInt(1, graph_id_);
+ insert.BindInt(2, node_id);
+ insert.BindInt(3, idx);
+ if (!name.empty() && name[0] == '^') {
+ name.remove_prefix(1);
+ insert.BindInt(5, 1);
+ }
+ auto e = name_to_node_id_.find(name);
+ if (e == name_to_node_id_.end()) {
+ return errors::DataLoss("Could not find node: ", name);
+ }
+ insert.BindInt(4, e->second);
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(),
+ " -> ", name);
+ }
+ }
+ return Status::OK();
+ }
+
+ Status SaveNodes() {
+ auto purge = db_->Prepare("DELETE FROM Nodes WHERE graph_id = ?");
+ purge.BindInt(1, graph_id_);
+ TF_RETURN_IF_ERROR(purge.StepAndReset());
+ auto insert = db_->Prepare(R"sql(
+ INSERT INTO Nodes (graph_id, node_id, node_name, op, device, node_def)
+ VALUES (?, ?, ?, ?, ?, ?)
+ )sql");
+ for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
+ NodeDef* node = graph_->mutable_node(node_id);
+ insert.BindInt(1, graph_id_);
+ insert.BindInt(2, node_id);
+ insert.BindText(3, node->name());
+ node->clear_name();
+ if (!node->op().empty()) {
+ insert.BindText(4, node->op());
+ node->clear_op();
+ }
+ if (!node->device().empty()) {
+ insert.BindText(5, node->device());
+ node->clear_device();
+ }
+ node->clear_input();
+ TF_RETURN_IF_ERROR(BindProto(&insert, 6, *node));
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name());
+ }
+ return Status::OK();
+ }
+
+ Status SaveGraph() {
+ auto insert = db_->Prepare(R"sql(
+ INSERT OR REPLACE INTO Graphs (graph_id, inserted_time, graph_def)
+ VALUES (?, ?, ?)
+ )sql");
+ insert.BindInt(1, graph_id_);
+ insert.BindDouble(2, GetWallTime(env_));
+ graph_->clear_node();
+ TF_RETURN_IF_ERROR(BindProto(&insert, 3, *graph_));
+ return insert.StepAndReset();
+ }
+
+ Env* env_;
+ Sqlite* db_;
+ GraphDef* graph_;
+ int64 graph_id_;
+ std::vector<string> name_copies_;
+ std::unordered_map<StringPiece, int64, StringPiece::Hasher> name_to_node_id_;
+};
+
class SummaryDbWriter : public SummaryWriterInterface {
public:
SummaryDbWriter(Env* env, std::shared_ptr<Sqlite> db)
- : SummaryWriterInterface(), env_(env), db_(std::move(db)), run_id_(-1) {}
+ : SummaryWriterInterface(),
+ env_(env),
+ db_(std::move(db)),
+ txn_(db_),
+ run_id_{0LL} {}
~SummaryDbWriter() override {}
Status Initialize(const string& experiment_name, const string& run_name,
@@ -76,7 +279,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
// TODO(@jart): Check for random ID collisions without needing txn retry.
insert_tensor_.BindInt(1, tag_id);
insert_tensor_.BindInt(2, global_step);
- insert_tensor_.BindDouble(3, GetWallTime());
+ insert_tensor_.BindDouble(3, GetWallTime(env_));
switch (t.dtype()) {
case DT_INT64:
insert_tensor_.BindInt(4, t.scalar<int64>()());
@@ -85,22 +288,41 @@ class SummaryDbWriter : public SummaryWriterInterface {
insert_tensor_.BindDouble(4, t.scalar<double>()());
break;
default:
- TF_RETURN_IF_ERROR(BindTensor(t));
+ TF_RETURN_IF_ERROR(BindTensor(&insert_tensor_, 4, t));
break;
}
return insert_tensor_.StepAndReset();
}
- Status WriteEvent(std::unique_ptr<Event> e) override {
+ Status WriteGraph(int64 global_step, std::unique_ptr<GraphDef> g) override {
mutex_lock ml(mu_);
TF_RETURN_IF_ERROR(InitializeParents());
- if (e->what_case() == Event::WhatCase::kSummary) {
- const Summary& summary = e->summary();
- for (int i = 0; i < summary.value_size(); ++i) {
- TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i)));
+ return txn_.Transact(GraphSaver::SaveToRun, env_, db_.get(), g.get(),
+ run_id_);
+ }
+
+ Status WriteEvent(std::unique_ptr<Event> e) override {
+ switch (e->what_case()) {
+ case Event::WhatCase::kSummary: {
+ mutex_lock ml(mu_);
+ TF_RETURN_IF_ERROR(InitializeParents());
+ const Summary& summary = e->summary();
+ for (int i = 0; i < summary.value_size(); ++i) {
+ TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i)));
+ }
+ return Status::OK();
}
+ case Event::WhatCase::kGraphDef: {
+ std::unique_ptr<GraphDef> graph{new GraphDef};
+ if (!ParseProtoUnlimited(graph.get(), e->graph_def())) {
+ return errors::DataLoss("parse event.graph_def failed");
+ }
+ return WriteGraph(e->step(), std::move(graph));
+ }
+ default:
+ // TODO(@jart): Handle other stuff.
+ return Status::OK();
}
- return Status::OK();
}
Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
@@ -136,33 +358,8 @@ class SummaryDbWriter : public SummaryWriterInterface {
string DebugString() override { return "SummaryDbWriter"; }
private:
- double GetWallTime() {
- // TODO(@jart): Follow precise definitions for time laid out in schema.
- // TODO(@jart): Use monotonic clock from gRPC codebase.
- return static_cast<double>(env_->NowMicros()) / 1.0e6;
- }
-
- Status BindTensor(const Tensor& t) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- // TODO(@jart): Make portable between little and big endian systems.
- // TODO(@jart): Use TensorChunks with minimal copying for big tensors.
- TensorProto p;
- t.AsProtoTensorContent(&p);
- string encoded;
- if (!p.SerializeToString(&encoded)) {
- return errors::DataLoss("SerializeToString failed");
- }
- // TODO(@jart): Put byte at beginning of blob to indicate encoding.
- // TODO(@jart): Allow crunch tool to re-compress with zlib instead.
- string compressed;
- if (!port::Snappy_Compress(encoded.data(), encoded.size(), &compressed)) {
- return errors::FailedPrecondition("TensorBase needs Snappy");
- }
- insert_tensor_.BindBlobUnsafe(4, compressed);
- return Status::OK();
- }
-
Status InitializeParents() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (run_id_ >= 0) {
+ if (run_id_ > 0) {
return Status::OK();
}
int64 user_id;
@@ -195,7 +392,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
)sql");
insert_user.BindInt(1, *user_id);
insert_user.BindText(2, user_name);
- insert_user.BindDouble(3, GetWallTime());
+ insert_user.BindDouble(3, GetWallTime(env_));
TF_RETURN_IF_ERROR(insert_user.StepAndReset());
}
return Status::OK();
@@ -249,7 +446,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
}
insert.BindInt(2, *id);
insert.BindText(3, name);
- insert.BindDouble(4, GetWallTime());
+ insert.BindDouble(4, GetWallTime(env_));
TF_RETURN_IF_ERROR(insert.StepAndReset());
}
return Status::OK();
@@ -276,6 +473,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
mutex mu_;
Env* env_;
std::shared_ptr<Sqlite> db_ GUARDED_BY(mu_);
+ Transactor txn_ GUARDED_BY(mu_);
SqliteStatement insert_tensor_ GUARDED_BY(mu_);
SqliteStatement update_metadata_ GUARDED_BY(mu_);
string user_name_ GUARDED_BY(mu_);
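The Transactor helper introduced above wraps a callback in BEGIN/COMMIT and rolls back when the callback returns a non-OK Status. A minimal usage sketch, assuming an already-constructed Transactor and an open database handle; the PurgeGraph function is hypothetical:

    // Hypothetical caller: delete a graph's rows in one atomic transaction.
    Status PurgeGraph(Transactor* txn, Sqlite* db, int64 graph_id) {
      return txn->Transact([&]() -> Status {
        auto d1 = db->Prepare("DELETE FROM Nodes WHERE graph_id = ?");
        d1.BindInt(1, graph_id);
        TF_RETURN_IF_ERROR(d1.StepAndReset());
        auto d2 = db->Prepare("DELETE FROM NodeInputs WHERE graph_id = ?");
        d2.BindInt(1, graph_id);
        return d2.StepAndReset();
      });
    }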
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
index c1af51e7b7..3431842ca2 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/db/sqlite.h"
@@ -212,5 +214,81 @@ TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) {
kTolerance);
}
+TEST_F(SummaryDbWriterTest, WriteGraph) {
+ TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "R", "", &env_, &writer_));
+ env_.AdvanceByMillis(23);
+ GraphDef graph;
+ NodeDef* node = graph.add_node();
+ node->set_name("x");
+ node->set_op("Placeholder");
+ node = graph.add_node();
+ node->set_name("y");
+ node->set_op("Placeholder");
+ node = graph.add_node();
+ node->set_name("z");
+ node->set_op("Love");
+ node = graph.add_node();
+ node->set_name("+");
+ node->set_op("Add");
+ node->add_input("x");
+ node->add_input("y");
+ node->add_input("^z");
+ node->set_device("tpu/lol");
+ std::unique_ptr<Event> e{new Event};
+ graph.SerializeToString(e->mutable_graph_def());
+ TF_ASSERT_OK(writer_->WriteEvent(std::move(e)));
+ TF_ASSERT_OK(writer_->Flush());
+ ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs"));
+ ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Graphs"));
+ ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Nodes"));
+ ASSERT_EQ(3LL, QueryInt("SELECT COUNT(*) FROM NodeInputs"));
+
+ int64 graph_id = QueryInt("SELECT graph_id FROM Graphs");
+ EXPECT_GT(graph_id, 0LL);
+ EXPECT_EQ(graph_id, QueryInt("SELECT graph_id FROM Runs"));
+ EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Graphs"));
+ EXPECT_FALSE(QueryString("SELECT graph_def FROM Graphs").empty());
+
+ EXPECT_EQ("x", QueryString("SELECT node_name FROM Nodes WHERE node_id = 0"));
+ EXPECT_EQ("y", QueryString("SELECT node_name FROM Nodes WHERE node_id = 1"));
+ EXPECT_EQ("z", QueryString("SELECT node_name FROM Nodes WHERE node_id = 2"));
+ EXPECT_EQ("+", QueryString("SELECT node_name FROM Nodes WHERE node_id = 3"));
+
+ EXPECT_EQ("Placeholder",
+ QueryString("SELECT op FROM Nodes WHERE node_id = 0"));
+ EXPECT_EQ("Placeholder",
+ QueryString("SELECT op FROM Nodes WHERE node_id = 1"));
+ EXPECT_EQ("Love", QueryString("SELECT op FROM Nodes WHERE node_id = 2"));
+ EXPECT_EQ("Add", QueryString("SELECT op FROM Nodes WHERE node_id = 3"));
+
+ EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 0"));
+ EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 1"));
+ EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 2"));
+ EXPECT_EQ("tpu/lol",
+ QueryString("SELECT device FROM Nodes WHERE node_id = 3"));
+
+ EXPECT_EQ(graph_id,
+ QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 0"));
+ EXPECT_EQ(graph_id,
+ QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 1"));
+ EXPECT_EQ(graph_id,
+ QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 2"));
+
+ EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 0"));
+ EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 1"));
+ EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 2"));
+
+ EXPECT_EQ(0LL,
+ QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 0"));
+ EXPECT_EQ(1LL,
+ QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 1"));
+ EXPECT_EQ(2LL,
+ QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 2"));
+
+ EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 0"));
+ EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 1"));
+ EXPECT_EQ(1LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 2"));
+}
+
} // namespace
} // namespace tensorflow
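The test above verifies the fan-in rows written for the "+" node. The reverse lookup is also possible with this schema: a hypothetical fan-out query listing every node that consumes a given node's output. There is no dedicated index on input_node_id, so this relies on the (graph_id, node_id, idx) index prefix only to restrict the scan to one graph's edges.

    SELECT dst.node_name, i.idx, i.is_control
      FROM NodeInputs AS i
      JOIN Nodes AS dst
        ON dst.graph_id = i.graph_id AND dst.node_id = i.node_id
     WHERE i.graph_id = :graph_id AND i.input_node_id = :node_id
     ORDER BY dst.node_name, i.idx;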