aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorboard
diff options
context:
space:
mode:
authorGravatar Justine Tunney <jart@google.com>2017-11-08 10:55:48 -0800
committerGravatar Andrew Selle <aselle@andyselle.com>2017-11-10 16:14:36 -0800
commit35cc51dc2a716c4b92429db60238e4f15fba1ed3 (patch)
tree397908ffa876253ea4230a0c13e83775841b0201 /tensorflow/contrib/tensorboard
parent4a618e411af3f808eb0f65ce4f7151450f1f16a5 (diff)
Add database writer ops to contrib/summary
PiperOrigin-RevId: 175030602
Diffstat (limited to 'tensorflow/contrib/tensorboard')
-rw-r--r--tensorflow/contrib/tensorboard/db/BUILD2
-rw-r--r--tensorflow/contrib/tensorboard/db/summary_db_writer.cc34
-rw-r--r--tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc56
3 files changed, 87 insertions, 5 deletions
diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD
index d8bbf87d2c..068e862650 100644
--- a/tensorflow/contrib/tensorboard/db/BUILD
+++ b/tensorflow/contrib/tensorboard/db/BUILD
@@ -45,10 +45,12 @@ cc_library(
tf_cc_test(
name = "summary_db_writer_test",
+ size = "small",
srcs = ["summary_db_writer_test.cc"],
deps = [
":summary_db_writer",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/db:sqlite",
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
index df64e36305..a26ad61660 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
@@ -15,10 +15,12 @@ limitations under the License.
#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
#include "tensorflow/contrib/tensorboard/db/schema.h"
+#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/db/sqlite.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/snappy.h"
+#include "tensorflow/core/util/event.pb.h"
namespace tensorflow {
namespace {
@@ -86,13 +88,19 @@ class SummaryDbWriter : public SummaryWriterInterface {
TF_RETURN_IF_ERROR(BindTensor(t));
break;
}
- TF_RETURN_IF_ERROR(insert_tensor_.StepAndReset());
- return Status::OK();
+ return insert_tensor_.StepAndReset();
}
Status WriteEvent(std::unique_ptr<Event> e) override {
- // TODO(@jart): This will be used to load event logs.
- return errors::Unimplemented("WriteEvent");
+ mutex_lock ml(mu_);
+ TF_RETURN_IF_ERROR(InitializeParents());
+ if (e->what_case() == Event::WhatCase::kSummary) {
+ const Summary& summary = e->summary();
+ for (int i = 0; i < summary.value_size(); ++i) {
+ TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i)));
+ }
+ }
+ return Status::OK();
}
Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
@@ -247,6 +255,24 @@ class SummaryDbWriter : public SummaryWriterInterface {
return Status::OK();
}
+ Status WriteSummary(const Event* e, const Summary::Value& summary)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 tag_id;
+ TF_RETURN_IF_ERROR(GetTagId(run_id_, summary.tag(), &tag_id));
+ insert_tensor_.BindInt(1, tag_id);
+ insert_tensor_.BindInt(2, e->step());
+ insert_tensor_.BindDouble(3, e->wall_time());
+ switch (summary.value_case()) {
+ case Summary::Value::ValueCase::kSimpleValue:
+ insert_tensor_.BindDouble(4, summary.simple_value());
+ break;
+ default:
+ // TODO(@jart): Handle the rest.
+ return Status::OK();
+ }
+ return insert_tensor_.StepAndReset();
+ }
+
mutex mu_;
Env* env_;
std::shared_ptr<Sqlite> db_ GUARDED_BY(mu_);
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
index d32904f97c..c1af51e7b7 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
@@ -14,14 +14,19 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
+#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/db/sqlite.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/event.pb.h"
namespace tensorflow {
namespace {
+const float kTolerance = 1e-5;
+
Tensor MakeScalarInt64(int64 x) {
Tensor t(DT_INT64, TensorShape({}));
t.scalar<int64>()() = x;
@@ -41,7 +46,7 @@ class FakeClockEnv : public EnvWrapper {
class SummaryDbWriterTest : public ::testing::Test {
protected:
- void SetUp() override { db_ = Sqlite::Open("file::memory:").ValueOrDie(); }
+ void SetUp() override { db_ = Sqlite::Open(":memory:").ValueOrDie(); }
void TearDown() override {
if (writer_ != nullptr) {
@@ -158,5 +163,54 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) {
QueryString("SELECT tensor FROM Tensors WHERE step = 2").empty());
}
+TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) {
+ TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
+ TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy",
+ "this-is-metaaa"));
+ TF_ASSERT_OK(writer_->Flush());
+ ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users"));
+ ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
+ ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs"));
+ ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags"));
+ ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+}
+
+TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) {
+ TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
+ std::unique_ptr<Event> e{new Event};
+ e->set_step(7);
+ e->set_wall_time(123.456);
+ Summary::Value* s = e->mutable_summary()->add_value();
+ s->set_tag("π");
+ s->set_simple_value(3.14f);
+ s = e->mutable_summary()->add_value();
+ s->set_tag("φ");
+ s->set_simple_value(1.61f);
+ TF_ASSERT_OK(writer_->WriteEvent(std::move(e)));
+ TF_ASSERT_OK(writer_->Flush());
+ ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags"));
+ ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+ int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'π'");
+ int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'φ'");
+ EXPECT_GT(tag1_id, 0LL);
+ EXPECT_GT(tag2_id, 0LL);
+ EXPECT_EQ(123.456, QueryDouble(strings::StrCat(
+ "SELECT computed_time FROM Tensors WHERE tag_id = ",
+ tag1_id, " AND step = 7")));
+ EXPECT_EQ(123.456, QueryDouble(strings::StrCat(
+ "SELECT computed_time FROM Tensors WHERE tag_id = ",
+ tag2_id, " AND step = 7")));
+ EXPECT_NEAR(3.14,
+ QueryDouble(strings::StrCat(
+ "SELECT tensor FROM Tensors WHERE tag_id = ", tag1_id,
+ " AND step = 7")),
+ kTolerance); // Summary::simple_value is float
+ EXPECT_NEAR(1.61,
+ QueryDouble(strings::StrCat(
+ "SELECT tensor FROM Tensors WHERE tag_id = ", tag2_id,
+ " AND step = 7")),
+ kTolerance);
+}
+
} // namespace
} // namespace tensorflow