diff options
author | Justine Tunney <jart@google.com> | 2017-11-08 10:55:48 -0800 |
---|---|---|
committer | Andrew Selle <aselle@andyselle.com> | 2017-11-10 16:14:36 -0800 |
commit | 35cc51dc2a716c4b92429db60238e4f15fba1ed3 (patch) | |
tree | 397908ffa876253ea4230a0c13e83775841b0201 /tensorflow/contrib/tensorboard | |
parent | 4a618e411af3f808eb0f65ce4f7151450f1f16a5 (diff) |
Add database writer ops to contrib/summary
PiperOrigin-RevId: 175030602
Diffstat (limited to 'tensorflow/contrib/tensorboard')
-rw-r--r-- | tensorflow/contrib/tensorboard/db/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/tensorboard/db/summary_db_writer.cc | 34 | ||||
-rw-r--r-- | tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc | 56 |
3 files changed, 87 insertions, 5 deletions
diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD index d8bbf87d2c..068e862650 100644 --- a/tensorflow/contrib/tensorboard/db/BUILD +++ b/tensorflow/contrib/tensorboard/db/BUILD @@ -45,10 +45,12 @@ cc_library( tf_cc_test( name = "summary_db_writer_test", + size = "small", srcs = ["summary_db_writer_test.cc"], deps = [ ":summary_db_writer", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/lib/db:sqlite", diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index df64e36305..a26ad61660 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -15,10 +15,12 @@ limitations under the License. #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" #include "tensorflow/contrib/tensorboard/db/schema.h" +#include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/lib/db/sqlite.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/snappy.h" +#include "tensorflow/core/util/event.pb.h" namespace tensorflow { namespace { @@ -86,13 +88,19 @@ class SummaryDbWriter : public SummaryWriterInterface { TF_RETURN_IF_ERROR(BindTensor(t)); break; } - TF_RETURN_IF_ERROR(insert_tensor_.StepAndReset()); - return Status::OK(); + return insert_tensor_.StepAndReset(); } Status WriteEvent(std::unique_ptr<Event> e) override { - // TODO(@jart): This will be used to load event logs. - return errors::Unimplemented("WriteEvent"); + mutex_lock ml(mu_); + TF_RETURN_IF_ERROR(InitializeParents()); + if (e->what_case() == Event::WhatCase::kSummary) { + const Summary& summary = e->summary(); + for (int i = 0; i < summary.value_size(); ++i) { + TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i))); + } + } + return Status::OK(); } Status WriteScalar(int64 global_step, Tensor t, const string& tag) override { @@ -247,6 +255,24 @@ class SummaryDbWriter : public SummaryWriterInterface { return Status::OK(); } + Status WriteSummary(const Event* e, const Summary::Value& summary) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int64 tag_id; + TF_RETURN_IF_ERROR(GetTagId(run_id_, summary.tag(), &tag_id)); + insert_tensor_.BindInt(1, tag_id); + insert_tensor_.BindInt(2, e->step()); + insert_tensor_.BindDouble(3, e->wall_time()); + switch (summary.value_case()) { + case Summary::Value::ValueCase::kSimpleValue: + insert_tensor_.BindDouble(4, summary.simple_value()); + break; + default: + // TODO(@jart): Handle the rest. + return Status::OK(); + } + return insert_tensor_.StepAndReset(); + } + mutex mu_; Env* env_; std::shared_ptr<Sqlite> db_ GUARDED_BY(mu_); diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc index d32904f97c..c1af51e7b7 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -14,14 +14,19 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" +#include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/db/sqlite.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/event.pb.h" namespace tensorflow { namespace { +const float kTolerance = 1e-5; + Tensor MakeScalarInt64(int64 x) { Tensor t(DT_INT64, TensorShape({})); t.scalar<int64>()() = x; @@ -41,7 +46,7 @@ class FakeClockEnv : public EnvWrapper { class SummaryDbWriterTest : public ::testing::Test { protected: - void SetUp() override { db_ = Sqlite::Open("file::memory:").ValueOrDie(); } + void SetUp() override { db_ = Sqlite::Open(":memory:").ValueOrDie(); } void TearDown() override { if (writer_ != nullptr) { @@ -158,5 +163,54 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) { QueryString("SELECT tensor FROM Tensors WHERE step = 2").empty()); } +TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) { + TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_)); + TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy", + "this-is-metaaa")); + TF_ASSERT_OK(writer_->Flush()); + ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users")); + ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments")); + ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs")); + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tensors")); +} + +TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { + TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_)); + std::unique_ptr<Event> e{new Event}; + e->set_step(7); + e->set_wall_time(123.456); + Summary::Value* s = e->mutable_summary()->add_value(); + s->set_tag("π"); + s->set_simple_value(3.14f); + s = e->mutable_summary()->add_value(); + s->set_tag("φ"); + s->set_simple_value(1.61f); + TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); + TF_ASSERT_OK(writer_->Flush()); + ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags")); + ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'π'"); + int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'φ'"); + EXPECT_GT(tag1_id, 0LL); + EXPECT_GT(tag2_id, 0LL); + EXPECT_EQ(123.456, QueryDouble(strings::StrCat( + "SELECT computed_time FROM Tensors WHERE tag_id = ", + tag1_id, " AND step = 7"))); + EXPECT_EQ(123.456, QueryDouble(strings::StrCat( + "SELECT computed_time FROM Tensors WHERE tag_id = ", + tag2_id, " AND step = 7"))); + EXPECT_NEAR(3.14, + QueryDouble(strings::StrCat( + "SELECT tensor FROM Tensors WHERE tag_id = ", tag1_id, + " AND step = 7")), + kTolerance); // Summary::simple_value is float + EXPECT_NEAR(1.61, + QueryDouble(strings::StrCat( + "SELECT tensor FROM Tensors WHERE tag_id = ", tag2_id, + " AND step = 7")), + kTolerance); +} + } // namespace } // namespace tensorflow |