diff options
author | Justine Tunney <jart@google.com> | 2017-11-15 11:31:43 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-15 11:35:49 -0800 |
commit | 6fb721d608c4cd3855fe8793099a629428b9853c (patch) | |
tree | faef08ed8bac4f5a8b065825a4405ef8a12e875f /tensorflow/contrib/tensorboard | |
parent | b7b183b90aee8a4f4808f7d90a2c7a54a942e640 (diff) |
Add graph writer op to contrib/summary
This change also defines a simple SQL data model for tf.GraphDef, which
should move us closer to a world where TensorBoard can render the graph
explorer without having to download the entire thing to the browser, as
that could potentially be hundreds of megabytes.
PiperOrigin-RevId: 175854921
Diffstat (limited to 'tensorflow/contrib/tensorboard')
-rw-r--r-- | tensorflow/contrib/tensorboard/db/schema.cc | 141 | ||||
-rw-r--r-- | tensorflow/contrib/tensorboard/db/summary_db_writer.cc | 272 | ||||
-rw-r--r-- | tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc | 78 |
3 files changed, 404 insertions, 87 deletions
diff --git a/tensorflow/contrib/tensorboard/db/schema.cc b/tensorflow/contrib/tensorboard/db/schema.cc index 98fff9e0ae..d63b2c6cc2 100644 --- a/tensorflow/contrib/tensorboard/db/schema.cc +++ b/tensorflow/contrib/tensorboard/db/schema.cc @@ -135,8 +135,7 @@ class SqliteSchema { /// the database. This field will be mutated if the run is /// restarted. /// description: Optional markdown information. - /// graph: Snappy tf.GraphDef proto with node field cleared. That - /// field can be recreated using GraphNodes and NodeDefs. + /// graph_id: ID of associated Graphs row. Status CreateRunsTable() { return Run(R"sql( CREATE TABLE IF NOT EXISTS Runs ( @@ -147,7 +146,7 @@ class SqliteSchema { inserted_time REAL, started_time REAL, description TEXT, - graph BLOB + graph_id INTEGER ) )sql"); } @@ -205,46 +204,78 @@ class SqliteSchema { )sql"); } - /// \brief Creates NodeDefs table. - /// - /// This table stores NodeDef protos which define the GraphDef for a - /// Run. This functions like a hash table so rows can be shared by - /// multiple Runs in an Experiment. + /// \brief Creates Graphs table. /// /// Fields: /// rowid: Ephemeral b-tree ID dictating locality. - /// experiment_id: Optional int64 for grouping rows. - /// node_def_id: Permanent >0 unique ID. - /// fingerprint: Optional farmhash::Fingerprint64() of uncompressed - /// node_def bytes, coerced to int64. - /// node_def: BLOB containing a Snappy tf.NodeDef proto. - Status CreateNodeDefsTable() { + /// graph_id: Permanent >0 unique ID. + /// inserted_time: Float UNIX timestamp with µs precision. This is + /// always the wall time of when the row was inserted into the + /// DB. It may be used as a hint for an archival job. + /// node_def: Contains Snappy tf.GraphDef proto. All fields will be + /// cleared except those not expressed in SQL. + Status CreateGraphsTable() { return Run(R"sql( - CREATE TABLE IF NOT EXISTS NodeDefs ( + CREATE TABLE IF NOT EXISTS Graphs ( rowid INTEGER PRIMARY KEY, - experiment_id INTEGER, - node_def_id INTEGER NOT NULL, - fingerprint INTEGER, - node_def TEXT + graph_id INTEGER NOT NULL, + inserted_time REAL, + graph_def BLOB ) )sql"); } - /// \brief Creates RunNodeDefs table. + /// \brief Creates Nodes table. /// - /// Table mapping Runs to NodeDefs. This is used to recreate the node - /// field of the GraphDef proto. + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// graph_id: Permanent >0 unique ID. + /// node_id: ID for this node. This is more like a 0-index within + /// the Graph. Please note indexes are allowed to be removed. + /// node_name: Unique name for this Node within Graph. This is + /// copied from the proto so it can be indexed. This is allowed + /// to be NULL to save space on the index, in which case the + /// node_def.name proto field must not be cleared. + /// op: Copied from tf.NodeDef proto. + /// device: Copied from tf.NodeDef proto. + /// node_def: Contains Snappy tf.NodeDef proto. All fields will be + /// cleared except those not expressed in SQL. + Status CreateNodesTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS Nodes ( + rowid INTEGER PRIMARY KEY, + graph_id INTEGER NOT NULL, + node_id INTEGER NOT NULL, + node_name TEXT, + op TEXT, + device TEXT, + node_def BLOB + ) + )sql"); + } + + /// \brief Creates NodeInputs table. /// /// Fields: /// rowid: Ephemeral b-tree ID dictating locality. - /// run_id: Mandatory ID of associated Run. - /// node_def_id: Mandatory ID of associated NodeDef. - Status CreateRunNodeDefsTable() { + /// graph_id: Permanent >0 unique ID. + /// node_id: Index of Node in question. This can be considered the + /// 'to' vertex. + /// idx: Used for ordering inputs on a given Node. + /// input_node_id: Nodes.node_id of the corresponding input node. + /// This can be considered the 'from' vertex. + /// is_control: If non-zero, indicates this input is a controlled + /// dependency, which means this isn't an edge through which + /// tensors flow. NULL means 0. + Status CreateNodeInputsTable() { return Run(R"sql( - CREATE TABLE IF NOT EXISTS RunNodeDefs ( + CREATE TABLE IF NOT EXISTS NodeInputs ( rowid INTEGER PRIMARY KEY, - run_id INTEGER NOT NULL, - node_def_id INTEGER NOT NULL + graph_id INTEGER NOT NULL, + node_id INTEGER NOT NULL, + idx INTEGER NOT NULL, + input_node_id INTEGER NOT NULL, + is_control INTEGER ) )sql"); } @@ -297,11 +328,27 @@ class SqliteSchema { )sql"); } - /// \brief Uniquely indexes node_def_id on NodeDefs table. - Status CreateNodeDefIdIndex() { + /// \brief Uniquely indexes graph_id on Graphs table. + Status CreateGraphIdIndex() { return Run(R"sql( - CREATE UNIQUE INDEX IF NOT EXISTS NodeDefIdIndex - ON NodeDefs (node_def_id) + CREATE UNIQUE INDEX IF NOT EXISTS GraphIdIndex + ON Graphs (graph_id) + )sql"); + } + + /// \brief Uniquely indexes (graph_id, node_id) on Nodes table. + Status CreateNodeIdIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS NodeIdIndex + ON Nodes (graph_id, node_id) + )sql"); + } + + /// \brief Uniquely indexes (graph_id, node_id, idx) on NodeInputs table. + Status CreateNodeInputsIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS NodeInputsIndex + ON NodeInputs (graph_id, node_id, idx) )sql"); } @@ -350,20 +397,12 @@ class SqliteSchema { )sql"); } - /// \brief Indexes (experiment_id, fingerprint) on NodeDefs table. - Status CreateNodeDefFingerprintIndex() { - return Run(R"sql( - CREATE INDEX IF NOT EXISTS NodeDefFingerprintIndex - ON NodeDefs (experiment_id, fingerprint) - WHERE fingerprint IS NOT NULL - )sql"); - } - - /// \brief Uniquely indexes (run_id, node_def_id) on RunNodeDefs table. - Status CreateRunNodeDefIndex() { + /// \brief Uniquely indexes (graph_id, node_name) on Nodes table. + Status CreateNodeNameIndex() { return Run(R"sql( - CREATE UNIQUE INDEX IF NOT EXISTS RunNodeDefIndex - ON RunNodeDefs (run_id, node_def_id) + CREATE UNIQUE INDEX IF NOT EXISTS NodeNameIndex + ON Nodes (graph_id, node_name) + WHERE node_name IS NOT NULL )sql"); } @@ -387,22 +426,24 @@ Status SetupTensorboardSqliteDb(std::shared_ptr<Sqlite> db) { TF_RETURN_IF_ERROR(s.CreateRunsTable()); TF_RETURN_IF_ERROR(s.CreateExperimentsTable()); TF_RETURN_IF_ERROR(s.CreateUsersTable()); - TF_RETURN_IF_ERROR(s.CreateNodeDefsTable()); - TF_RETURN_IF_ERROR(s.CreateRunNodeDefsTable()); + TF_RETURN_IF_ERROR(s.CreateGraphsTable()); + TF_RETURN_IF_ERROR(s.CreateNodeInputsTable()); + TF_RETURN_IF_ERROR(s.CreateNodesTable()); TF_RETURN_IF_ERROR(s.CreateTensorIndex()); TF_RETURN_IF_ERROR(s.CreateTensorChunkIndex()); TF_RETURN_IF_ERROR(s.CreateTagIdIndex()); TF_RETURN_IF_ERROR(s.CreateRunIdIndex()); TF_RETURN_IF_ERROR(s.CreateExperimentIdIndex()); TF_RETURN_IF_ERROR(s.CreateUserIdIndex()); - TF_RETURN_IF_ERROR(s.CreateNodeDefIdIndex()); + TF_RETURN_IF_ERROR(s.CreateGraphIdIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeIdIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeInputsIndex()); TF_RETURN_IF_ERROR(s.CreateTagNameIndex()); TF_RETURN_IF_ERROR(s.CreateRunNameIndex()); TF_RETURN_IF_ERROR(s.CreateExperimentNameIndex()); TF_RETURN_IF_ERROR(s.CreateUserNameIndex()); TF_RETURN_IF_ERROR(s.CreateUserEmailIndex()); - TF_RETURN_IF_ERROR(s.CreateNodeDefFingerprintIndex()); - TF_RETURN_IF_ERROR(s.CreateRunNodeDefIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeNameIndex()); return Status::OK(); } diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index a26ad61660..ae063d24ef 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -15,17 +15,29 @@ limitations under the License. #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" #include "tensorflow/contrib/tensorboard/db/schema.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/db/sqlite.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/snappy.h" #include "tensorflow/core/util/event.pb.h" namespace tensorflow { namespace { +double GetWallTime(Env* env) { + // TODO(@jart): Follow precise definitions for time laid out in schema. + // TODO(@jart): Use monotonic clock from gRPC codebase. + return static_cast<double>(env->NowMicros()) / 1.0e6; +} + int64 MakeRandomId() { + // TODO(@jart): Try generating ID in 2^24 space, falling back to 2^63 + // https://sqlite.org/src4/doc/trunk/www/varint.wiki int64 id = static_cast<int64>(random::New64() & ((1ULL << 63) - 1)); if (id == 0) { ++id; @@ -33,10 +45,201 @@ int64 MakeRandomId() { return id; } +Status Serialize(const protobuf::MessageLite& proto, string* output) { + output->clear(); + if (!proto.SerializeToString(output)) { + return errors::DataLoss("SerializeToString failed"); + } + return Status::OK(); +} + +Status Compress(const string& data, string* output) { + output->clear(); + if (!port::Snappy_Compress(data.data(), data.size(), output)) { + return errors::FailedPrecondition("TensorBase needs Snappy"); + } + return Status::OK(); +} + +Status BindProto(SqliteStatement* stmt, int parameter, + const protobuf::MessageLite& proto) { + string serialized; + TF_RETURN_IF_ERROR(Serialize(proto, &serialized)); + string compressed; + TF_RETURN_IF_ERROR(Compress(serialized, &compressed)); + stmt->BindBlobUnsafe(parameter, compressed); + return Status::OK(); +} + +Status BindTensor(SqliteStatement* stmt, int parameter, const Tensor& t) { + // TODO(@jart): Make portable between little and big endian systems. + // TODO(@jart): Use TensorChunks with minimal copying for big tensors. + // TODO(@jart): Add field to indicate encoding. + // TODO(@jart): Allow crunch tool to re-compress with zlib instead. + TensorProto p; + t.AsProtoTensorContent(&p); + return BindProto(stmt, parameter, p); +} + +class Transactor { + public: + explicit Transactor(std::shared_ptr<Sqlite> db) + : db_(std::move(db)), + begin_(db_->Prepare("BEGIN TRANSACTION")), + commit_(db_->Prepare("COMMIT TRANSACTION")), + rollback_(db_->Prepare("ROLLBACK TRANSACTION")) {} + + template <typename T, typename... Args> + Status Transact(T callback, Args&&... args) { + TF_RETURN_IF_ERROR(begin_.StepAndReset()); + Status s = callback(std::forward<Args>(args)...); + if (s.ok()) { + TF_RETURN_IF_ERROR(commit_.StepAndReset()); + } else { + TF_RETURN_WITH_CONTEXT_IF_ERROR(rollback_.StepAndReset(), s.ToString()); + } + return s; + } + + private: + std::shared_ptr<Sqlite> db_; + SqliteStatement begin_; + SqliteStatement commit_; + SqliteStatement rollback_; +}; + +class GraphSaver { + public: + static Status SaveToRun(Env* env, Sqlite* db, GraphDef* graph, int64 run_id) { + auto get = db->Prepare("SELECT graph_id FROM Runs WHERE run_id = ?"); + get.BindInt(1, run_id); + bool is_done; + TF_RETURN_IF_ERROR(get.Step(&is_done)); + int64 graph_id = is_done ? 0 : get.ColumnInt(0); + if (graph_id == 0) { + graph_id = MakeRandomId(); + // TODO(@jart): Check for ID collision. + auto set = db->Prepare("UPDATE Runs SET graph_id = ? WHERE run_id = ?"); + set.BindInt(1, graph_id); + set.BindInt(2, run_id); + TF_RETURN_IF_ERROR(set.StepAndReset()); + } + return Save(env, db, graph, graph_id); + } + + static Status Save(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id) { + GraphSaver saver{env, db, graph, graph_id}; + saver.MapNameToNodeId(); + TF_RETURN_IF_ERROR(saver.SaveNodeInputs()); + TF_RETURN_IF_ERROR(saver.SaveNodes()); + TF_RETURN_IF_ERROR(saver.SaveGraph()); + return Status::OK(); + } + + private: + GraphSaver(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id) + : env_(env), db_(db), graph_(graph), graph_id_(graph_id) {} + + void MapNameToNodeId() { + size_t toto = static_cast<size_t>(graph_->node_size()); + name_copies_.reserve(toto); + name_to_node_id_.reserve(toto); + for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { + // Copy name into memory region, since we call clear_name() later. + // Then wrap in StringPiece so we can compare slices without copy. + name_copies_.emplace_back(graph_->node(node_id).name()); + name_to_node_id_.emplace(name_copies_.back(), node_id); + } + } + + Status SaveNodeInputs() { + auto purge = db_->Prepare("DELETE FROM NodeInputs WHERE graph_id = ?"); + purge.BindInt(1, graph_id_); + TF_RETURN_IF_ERROR(purge.StepAndReset()); + auto insert = db_->Prepare(R"sql( + INSERT INTO NodeInputs (graph_id, node_id, idx, input_node_id, is_control) + VALUES (?, ?, ?, ?, ?) + )sql"); + for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { + const NodeDef& node = graph_->node(node_id); + for (int idx = 0; idx < node.input_size(); ++idx) { + StringPiece name = node.input(idx); + insert.BindInt(1, graph_id_); + insert.BindInt(2, node_id); + insert.BindInt(3, idx); + if (!name.empty() && name[0] == '^') { + name.remove_prefix(1); + insert.BindInt(5, 1); + } + auto e = name_to_node_id_.find(name); + if (e == name_to_node_id_.end()) { + return errors::DataLoss("Could not find node: ", name); + } + insert.BindInt(4, e->second); + TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(), + " -> ", name); + } + } + return Status::OK(); + } + + Status SaveNodes() { + auto purge = db_->Prepare("DELETE FROM Nodes WHERE graph_id = ?"); + purge.BindInt(1, graph_id_); + TF_RETURN_IF_ERROR(purge.StepAndReset()); + auto insert = db_->Prepare(R"sql( + INSERT INTO Nodes (graph_id, node_id, node_name, op, device, node_def) + VALUES (?, ?, ?, ?, ?, ?) + )sql"); + for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { + NodeDef* node = graph_->mutable_node(node_id); + insert.BindInt(1, graph_id_); + insert.BindInt(2, node_id); + insert.BindText(3, node->name()); + node->clear_name(); + if (!node->op().empty()) { + insert.BindText(4, node->op()); + node->clear_op(); + } + if (!node->device().empty()) { + insert.BindText(5, node->device()); + node->clear_device(); + } + node->clear_input(); + TF_RETURN_IF_ERROR(BindProto(&insert, 6, *node)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name()); + } + return Status::OK(); + } + + Status SaveGraph() { + auto insert = db_->Prepare(R"sql( + INSERT OR REPLACE INTO Graphs (graph_id, inserted_time, graph_def) + VALUES (?, ?, ?) + )sql"); + insert.BindInt(1, graph_id_); + insert.BindDouble(2, GetWallTime(env_)); + graph_->clear_node(); + TF_RETURN_IF_ERROR(BindProto(&insert, 3, *graph_)); + return insert.StepAndReset(); + } + + Env* env_; + Sqlite* db_; + GraphDef* graph_; + int64 graph_id_; + std::vector<string> name_copies_; + std::unordered_map<StringPiece, int64, StringPiece::Hasher> name_to_node_id_; +}; + class SummaryDbWriter : public SummaryWriterInterface { public: SummaryDbWriter(Env* env, std::shared_ptr<Sqlite> db) - : SummaryWriterInterface(), env_(env), db_(std::move(db)), run_id_(-1) {} + : SummaryWriterInterface(), + env_(env), + db_(std::move(db)), + txn_(db_), + run_id_{0LL} {} ~SummaryDbWriter() override {} Status Initialize(const string& experiment_name, const string& run_name, @@ -76,7 +279,7 @@ class SummaryDbWriter : public SummaryWriterInterface { // TODO(@jart): Check for random ID collisions without needing txn retry. insert_tensor_.BindInt(1, tag_id); insert_tensor_.BindInt(2, global_step); - insert_tensor_.BindDouble(3, GetWallTime()); + insert_tensor_.BindDouble(3, GetWallTime(env_)); switch (t.dtype()) { case DT_INT64: insert_tensor_.BindInt(4, t.scalar<int64>()()); @@ -85,22 +288,41 @@ class SummaryDbWriter : public SummaryWriterInterface { insert_tensor_.BindDouble(4, t.scalar<double>()()); break; default: - TF_RETURN_IF_ERROR(BindTensor(t)); + TF_RETURN_IF_ERROR(BindTensor(&insert_tensor_, 4, t)); break; } return insert_tensor_.StepAndReset(); } - Status WriteEvent(std::unique_ptr<Event> e) override { + Status WriteGraph(int64 global_step, std::unique_ptr<GraphDef> g) override { mutex_lock ml(mu_); TF_RETURN_IF_ERROR(InitializeParents()); - if (e->what_case() == Event::WhatCase::kSummary) { - const Summary& summary = e->summary(); - for (int i = 0; i < summary.value_size(); ++i) { - TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i))); + return txn_.Transact(GraphSaver::SaveToRun, env_, db_.get(), g.get(), + run_id_); + } + + Status WriteEvent(std::unique_ptr<Event> e) override { + switch (e->what_case()) { + case Event::WhatCase::kSummary: { + mutex_lock ml(mu_); + TF_RETURN_IF_ERROR(InitializeParents()); + const Summary& summary = e->summary(); + for (int i = 0; i < summary.value_size(); ++i) { + TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i))); + } + return Status::OK(); } + case Event::WhatCase::kGraphDef: { + std::unique_ptr<GraphDef> graph{new GraphDef}; + if (!ParseProtoUnlimited(graph.get(), e->graph_def())) { + return errors::DataLoss("parse event.graph_def failed"); + } + return WriteGraph(e->step(), std::move(graph)); + } + default: + // TODO(@jart): Handle other stuff. + return Status::OK(); } - return Status::OK(); } Status WriteScalar(int64 global_step, Tensor t, const string& tag) override { @@ -136,33 +358,8 @@ class SummaryDbWriter : public SummaryWriterInterface { string DebugString() override { return "SummaryDbWriter"; } private: - double GetWallTime() { - // TODO(@jart): Follow precise definitions for time laid out in schema. - // TODO(@jart): Use monotonic clock from gRPC codebase. - return static_cast<double>(env_->NowMicros()) / 1.0e6; - } - - Status BindTensor(const Tensor& t) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - // TODO(@jart): Make portable between little and big endian systems. - // TODO(@jart): Use TensorChunks with minimal copying for big tensors. - TensorProto p; - t.AsProtoTensorContent(&p); - string encoded; - if (!p.SerializeToString(&encoded)) { - return errors::DataLoss("SerializeToString failed"); - } - // TODO(@jart): Put byte at beginning of blob to indicate encoding. - // TODO(@jart): Allow crunch tool to re-compress with zlib instead. - string compressed; - if (!port::Snappy_Compress(encoded.data(), encoded.size(), &compressed)) { - return errors::FailedPrecondition("TensorBase needs Snappy"); - } - insert_tensor_.BindBlobUnsafe(4, compressed); - return Status::OK(); - } - Status InitializeParents() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (run_id_ >= 0) { + if (run_id_ > 0) { return Status::OK(); } int64 user_id; @@ -195,7 +392,7 @@ class SummaryDbWriter : public SummaryWriterInterface { )sql"); insert_user.BindInt(1, *user_id); insert_user.BindText(2, user_name); - insert_user.BindDouble(3, GetWallTime()); + insert_user.BindDouble(3, GetWallTime(env_)); TF_RETURN_IF_ERROR(insert_user.StepAndReset()); } return Status::OK(); @@ -249,7 +446,7 @@ class SummaryDbWriter : public SummaryWriterInterface { } insert.BindInt(2, *id); insert.BindText(3, name); - insert.BindDouble(4, GetWallTime()); + insert.BindDouble(4, GetWallTime(env_)); TF_RETURN_IF_ERROR(insert.StepAndReset()); } return Status::OK(); @@ -276,6 +473,7 @@ class SummaryDbWriter : public SummaryWriterInterface { mutex mu_; Env* env_; std::shared_ptr<Sqlite> db_ GUARDED_BY(mu_); + Transactor txn_ GUARDED_BY(mu_); SqliteStatement insert_tensor_ GUARDED_BY(mu_); SqliteStatement update_metadata_ GUARDED_BY(mu_); string user_name_ GUARDED_BY(mu_); diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc index c1af51e7b7..3431842ca2 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/db/sqlite.h" @@ -212,5 +214,81 @@ TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { kTolerance); } +TEST_F(SummaryDbWriterTest, WriteGraph) { + TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "R", "", &env_, &writer_)); + env_.AdvanceByMillis(23); + GraphDef graph; + NodeDef* node = graph.add_node(); + node->set_name("x"); + node->set_op("Placeholder"); + node = graph.add_node(); + node->set_name("y"); + node->set_op("Placeholder"); + node = graph.add_node(); + node->set_name("z"); + node->set_op("Love"); + node = graph.add_node(); + node->set_name("+"); + node->set_op("Add"); + node->add_input("x"); + node->add_input("y"); + node->add_input("^z"); + node->set_device("tpu/lol"); + std::unique_ptr<Event> e{new Event}; + graph.SerializeToString(e->mutable_graph_def()); + TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); + TF_ASSERT_OK(writer_->Flush()); + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs")); + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Graphs")); + ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Nodes")); + ASSERT_EQ(3LL, QueryInt("SELECT COUNT(*) FROM NodeInputs")); + + int64 graph_id = QueryInt("SELECT graph_id FROM Graphs"); + EXPECT_GT(graph_id, 0LL); + EXPECT_EQ(graph_id, QueryInt("SELECT graph_id FROM Runs")); + EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Graphs")); + EXPECT_FALSE(QueryString("SELECT graph_def FROM Graphs").empty()); + + EXPECT_EQ("x", QueryString("SELECT node_name FROM Nodes WHERE node_id = 0")); + EXPECT_EQ("y", QueryString("SELECT node_name FROM Nodes WHERE node_id = 1")); + EXPECT_EQ("z", QueryString("SELECT node_name FROM Nodes WHERE node_id = 2")); + EXPECT_EQ("+", QueryString("SELECT node_name FROM Nodes WHERE node_id = 3")); + + EXPECT_EQ("Placeholder", + QueryString("SELECT op FROM Nodes WHERE node_id = 0")); + EXPECT_EQ("Placeholder", + QueryString("SELECT op FROM Nodes WHERE node_id = 1")); + EXPECT_EQ("Love", QueryString("SELECT op FROM Nodes WHERE node_id = 2")); + EXPECT_EQ("Add", QueryString("SELECT op FROM Nodes WHERE node_id = 3")); + + EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 0")); + EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 1")); + EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 2")); + EXPECT_EQ("tpu/lol", + QueryString("SELECT device FROM Nodes WHERE node_id = 3")); + + EXPECT_EQ(graph_id, + QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(graph_id, + QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(graph_id, + QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 2")); + + EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 2")); + + EXPECT_EQ(0LL, + QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(1LL, + QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(2LL, + QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 2")); + + EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(1LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 2")); +} + } // namespace } // namespace tensorflow |