diff options
author | 2018-05-22 15:30:02 -0700 | |
---|---|---|
committer | 2018-05-22 15:32:36 -0700 | |
commit | b7d3d31a78ce90fd9733d67247ae34c694199d19 (patch) | |
tree | 7871818397a4a63c4f891c753fafa243977b7e99 /tensorflow/contrib/tensorboard | |
parent | 17272b4d1ccb5c7bd0bc3015c34f8bd769516354 (diff) |
Remove reservoir sampling from SummaryDbWriter
PiperOrigin-RevId: 197634162
Diffstat (limited to 'tensorflow/contrib/tensorboard')
-rw-r--r-- | tensorflow/contrib/tensorboard/db/summary_db_writer.cc | 153 | ||||
-rw-r--r-- | tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc | 6 |
2 files changed, 27 insertions, 132 deletions
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index 6590d6f7df..d5d8e4100f 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" +#include <deque> + #include "tensorflow/contrib/tensorboard/db/summary_converter.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -66,14 +68,9 @@ const char* kImagePluginName = "images"; const char* kAudioPluginName = "audio"; const char* kHistogramPluginName = "histograms"; -const int kScalarSlots = 10000; -const int kImageSlots = 10; -const int kAudioSlots = 10; -const int kHistogramSlots = 1; -const int kTensorSlots = 10; - const int64 kReserveMinBytes = 32; const double kReserveMultiplier = 1.5; +const int64 kPreallocateRows = 1000; // Flush is a misnomer because what we're actually doing is having lots // of commits inside any SqliteTransaction that writes potentially @@ -139,22 +136,6 @@ void PatchPluginName(SummaryMetadata* metadata, const char* name) { } } -int GetSlots(const Tensor& t, const SummaryMetadata& metadata) { - if (metadata.plugin_data().plugin_name() == kScalarPluginName) { - return kScalarSlots; - } else if (metadata.plugin_data().plugin_name() == kImagePluginName) { - return kImageSlots; - } else if (metadata.plugin_data().plugin_name() == kAudioPluginName) { - return kAudioSlots; - } else if (metadata.plugin_data().plugin_name() == kHistogramPluginName) { - return kHistogramSlots; - } else if (t.dims() == 0 && t.dtype() != DT_STRING) { - return kScalarSlots; - } else { - return kTensorSlots; - } -} - Status SetDescription(Sqlite* db, int64 id, const StringPiece& markdown) { const char* sql = R"sql( INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?) @@ -481,24 +462,6 @@ class RunMetadata { return insert.StepAndReset(); } - Status GetIsWatching(Sqlite* db, bool* is_watching) - SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { - mutex_lock lock(mu_); - if (experiment_id_ == kAbsent) { - *is_watching = true; - return Status::OK(); - } - const char* sql = R"sql( - SELECT is_watching FROM Experiments WHERE experiment_id = ? - )sql"; - SqliteStatement stmt; - TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt)); - stmt.BindInt(1, experiment_id_); - TF_RETURN_IF_ERROR(stmt.StepOnce()); - *is_watching = stmt.ColumnInt(0) != 0; - return Status::OK(); - } - private: Status InitializeUser(Sqlite* db, uint64 now) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (user_id_ != kAbsent || user_name_.empty()) return Status::OK(); @@ -659,43 +622,15 @@ class RunMetadata { /// \brief Tensor writer for a single series, e.g. Tag. /// -/// This class can be used to write an infinite stream of Tensors to the -/// database in a fixed block of contiguous disk space. This is -/// accomplished using Algorithm R reservoir sampling. -/// -/// The reservoir consists of a fixed number of rows, which are inserted -/// using ZEROBLOB upon receiving the first sample, which is used to -/// predict how big the other ones are likely to be. This is done -/// transactionally in a way that tries to be mindful of other processes -/// that might be trying to access the same DB. -/// -/// Once the reservoir fills up, rows are replaced at random, and writes -/// gradually become no-ops. This allows long training to go fast -/// without configuration. The exception is when someone is actually -/// looking at TensorBoard. When that happens, the "keep last" behavior -/// is turned on and Append() will always result in a write. -/// -/// If no one is watching training, this class still holds on to the -/// most recent "dangling" Tensor, so if Finish() is called, the most -/// recent training state can be written to disk. -/// -/// The randomly selected sampling points should be consistent across -/// multiple instances. -/// /// This class is thread safe. class SeriesWriter { public: - SeriesWriter(int64 series, int slots, RunMetadata* meta) - : series_{series}, - slots_{slots}, - meta_{meta}, - rng_{std::mt19937_64::default_seed} { + SeriesWriter(int64 series, RunMetadata* meta) : series_{series}, meta_{meta} { DCHECK(series_ > 0); - DCHECK(slots_ > 0); } Status Append(Sqlite* db, int64 step, uint64 now, double computed_time, - Tensor t) SQLITE_TRANSACTIONS_EXCLUDED(*db) + const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); if (rowids_.empty()) { @@ -705,41 +640,20 @@ class SeriesWriter { return s; } } - DCHECK(rowids_.size() == slots_); - int64 rowid; - size_t i = count_; - if (i < slots_) { - rowid = last_rowid_ = rowids_[i]; - } else { - i = rng_() % (i + 1); - if (i < slots_) { - rowid = last_rowid_ = rowids_[i]; - } else { - bool keep_last; - TF_RETURN_IF_ERROR(meta_->GetIsWatching(db, &keep_last)); - if (!keep_last) { - ++count_; - dangling_tensor_.reset(new Tensor(std::move(t))); - dangling_step_ = step; - dangling_computed_time_ = computed_time; - return Status::OK(); - } - rowid = last_rowid_; - } - } + int64 rowid = rowids_.front(); Status s = Write(db, rowid, step, computed_time, t); if (s.ok()) { ++count_; - dangling_tensor_.reset(); } + rowids_.pop_front(); return s; } Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); - // Short runs: Delete unused pre-allocated Tensors. - if (count_ < rowids_.size()) { + // Delete unused pre-allocated Tensors. + if (!rowids_.empty()) { SqliteTransaction txn(*db); const char* sql = R"sql( DELETE FROM Tensors WHERE rowid = ? @@ -747,19 +661,13 @@ class SeriesWriter { SqliteStatement deleter; TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter)); for (size_t i = count_; i < rowids_.size(); ++i) { - deleter.BindInt(1, rowids_[i]); + deleter.BindInt(1, rowids_.front()); TF_RETURN_IF_ERROR(deleter.StepAndReset()); + rowids_.pop_front(); } TF_RETURN_IF_ERROR(txn.Commit()); rowids_.clear(); } - // Long runs: Make last sample be the very most recent one. - if (dangling_tensor_) { - DCHECK(last_rowid_ != kAbsent); - TF_RETURN_IF_ERROR(Write(db, last_rowid_, dangling_step_, - dangling_computed_time_, *dangling_tensor_)); - dangling_tensor_.reset(); - } return Status::OK(); } @@ -783,7 +691,6 @@ class SeriesWriter { Status Update(Sqlite* db, int64 step, double computed_time, const Tensor& t, const StringPiece& data, int64 rowid) { - // TODO(jart): How can we ensure reservoir fills on replace? const char* sql = R"sql( UPDATE OR REPLACE Tensors @@ -878,7 +785,7 @@ class SeriesWriter { // TODO(jart): Maybe preallocate index pages by setting step. This // is tricky because UPDATE OR REPLACE can have a side // effect of deleting preallocated rows. - for (int64 i = 0; i < slots_; ++i) { + for (int64 i = 0; i < kPreallocateRows; ++i) { insert.BindInt(1, series_); insert.BindInt(2, reserved_bytes); TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i); @@ -902,16 +809,10 @@ class SeriesWriter { mutex mu_; const int64 series_; - const int slots_; RunMetadata* const meta_; - std::mt19937_64 rng_ GUARDED_BY(mu_); uint64 count_ GUARDED_BY(mu_) = 0; - int64 last_rowid_ GUARDED_BY(mu_) = kAbsent; - std::vector<int64> rowids_ GUARDED_BY(mu_); + std::deque<int64> rowids_ GUARDED_BY(mu_); uint64 unflushed_bytes_ GUARDED_BY(mu_) = 0; - std::unique_ptr<Tensor> dangling_tensor_ GUARDED_BY(mu_); - int64 dangling_step_ GUARDED_BY(mu_) = 0; - double dangling_computed_time_ GUARDED_BY(mu_) = 0.0; TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter); }; @@ -928,10 +829,10 @@ class RunWriter { explicit RunWriter(RunMetadata* meta) : meta_{meta} {} Status Append(Sqlite* db, int64 tag_id, int64 step, uint64 now, - double computed_time, Tensor t, int slots) + double computed_time, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { - SeriesWriter* writer = GetSeriesWriter(tag_id, slots); - return writer->Append(db, step, now, computed_time, std::move(t)); + SeriesWriter* writer = GetSeriesWriter(tag_id); + return writer->Append(db, step, now, computed_time, t); } Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) @@ -948,11 +849,11 @@ class RunWriter { } private: - SeriesWriter* GetSeriesWriter(int64 tag_id, int slots) LOCKS_EXCLUDED(mu_) { + SeriesWriter* GetSeriesWriter(int64 tag_id) LOCKS_EXCLUDED(mu_) { mutex_lock sl(mu_); auto spot = series_writers_.find(tag_id); if (spot == series_writers_.end()) { - SeriesWriter* writer = new SeriesWriter(tag_id, slots, meta_); + SeriesWriter* writer = new SeriesWriter(tag_id, meta_); series_writers_[tag_id].reset(writer); return writer; } else { @@ -1082,8 +983,7 @@ class SummaryDbWriter : public SummaryWriterInterface { TF_RETURN_IF_ERROR( meta_.GetTagId(db_, now, computed_time, tag, &tag_id, metadata)); TF_RETURN_WITH_CONTEXT_IF_ERROR( - run_.Append(db_, tag_id, step, now, computed_time, t, - GetSlots(t, metadata)), + run_.Append(db_, tag_id, step, now, computed_time, t), meta_.user_name(), "/", meta_.experiment_name(), "/", meta_.run_name(), "/", tag, "@", step); return Status::OK(); @@ -1155,8 +1055,7 @@ class SummaryDbWriter : public SummaryWriterInterface { int64 tag_id; TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t, - GetSlots(t, s->metadata())); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } // TODO(jart): Refactor Summary -> Tensor logic into separate file. @@ -1169,8 +1068,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kScalarPluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - std::move(t), kScalarSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) { @@ -1195,8 +1093,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kHistogramPluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - std::move(t), kHistogramSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) { @@ -1210,8 +1107,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kImagePluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - std::move(t), kImageSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) { @@ -1224,8 +1120,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kAudioPluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - std::move(t), kAudioSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Env* const env_; diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc index 29b8063218..c34b6763a1 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -139,7 +139,7 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) { ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Experiments")); ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs")); ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); - ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); int64 user_id = QueryInt("SELECT user_id FROM Users"); int64 experiment_id = QueryInt("SELECT experiment_id FROM Experiments"); @@ -188,7 +188,7 @@ TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) { ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments")); ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs")); ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); - ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); } TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { @@ -205,7 +205,7 @@ TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); TF_ASSERT_OK(writer_->Flush()); ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags")); - ASSERT_EQ(20000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + ASSERT_EQ(2000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'π'"); int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'φ'"); EXPECT_GT(tag1_id, 0LL); |