diff options
author | Justine Tunney <jart@google.com> | 2017-11-16 19:30:05 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-16 19:34:44 -0800 |
commit | 7d17d27940aa915583b0b3e2ba77d9f708af6783 (patch) | |
tree | 5022fc732328c2f9578876a3630348485bcbb666 /tensorflow/contrib/tensorboard | |
parent | 929178e1046f6387d9245c3d89ba5c3c1f3078d5 (diff) |
Add WriteScalar support to SummaryDbWriter
PiperOrigin-RevId: 176058700
Diffstat (limited to 'tensorflow/contrib/tensorboard')
-rw-r--r-- | tensorflow/contrib/tensorboard/db/summary_db_writer.cc | 81 | ||||
-rw-r--r-- | tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc | 27 |
2 files changed, 89 insertions, 19 deletions
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index ae063d24ef..857e731ef2 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -81,6 +81,55 @@ Status BindTensor(SqliteStatement* stmt, int parameter, const Tensor& t) { return BindProto(stmt, parameter, p); } +// Tries to fudge shape and dtype to something with smaller storage. +Status CoerceScalar(const Tensor& t, Tensor* out) { + switch (t.dtype()) { + case DT_DOUBLE: + *out = t; + break; + case DT_INT64: + *out = t; + break; + case DT_FLOAT: + *out = {DT_DOUBLE, {}}; + out->scalar<double>()() = t.scalar<float>()(); + break; + case DT_HALF: + *out = {DT_DOUBLE, {}}; + out->scalar<double>()() = static_cast<double>(t.scalar<Eigen::half>()()); + break; + case DT_INT32: + *out = {DT_INT64, {}}; + out->scalar<int64>()() = t.scalar<int32>()(); + break; + case DT_INT16: + *out = {DT_INT64, {}}; + out->scalar<int64>()() = t.scalar<int16>()(); + break; + case DT_INT8: + *out = {DT_INT64, {}}; + out->scalar<int64>()() = t.scalar<int8>()(); + break; + case DT_UINT32: + *out = {DT_INT64, {}}; + out->scalar<int64>()() = t.scalar<uint32>()(); + break; + case DT_UINT16: + *out = {DT_INT64, {}}; + out->scalar<int64>()() = t.scalar<uint16>()(); + break; + case DT_UINT8: + *out = {DT_INT64, {}}; + out->scalar<int64>()() = t.scalar<uint8>()(); + break; + default: + return errors::Unimplemented("Scalar summary for dtype ", + DataTypeString(t.dtype()), + " is not supported."); + } + return Status::OK(); +} + class Transactor { public: explicit Transactor(std::shared_ptr<Sqlite> db) @@ -280,20 +329,23 @@ class SummaryDbWriter : public SummaryWriterInterface { insert_tensor_.BindInt(1, tag_id); insert_tensor_.BindInt(2, global_step); insert_tensor_.BindDouble(3, GetWallTime(env_)); - switch (t.dtype()) { - case DT_INT64: - insert_tensor_.BindInt(4, t.scalar<int64>()()); - break; - case DT_DOUBLE: - insert_tensor_.BindDouble(4, t.scalar<double>()()); - break; - default: - TF_RETURN_IF_ERROR(BindTensor(&insert_tensor_, 4, t)); - break; + if (t.shape().dims() == 0 && t.dtype() == DT_INT64) { + insert_tensor_.BindInt(4, t.scalar<int64>()()); + } else if (t.shape().dims() == 0 && t.dtype() == DT_DOUBLE) { + insert_tensor_.BindDouble(4, t.scalar<double>()()); + } else { + TF_RETURN_IF_ERROR(BindTensor(&insert_tensor_, 4, t)); } return insert_tensor_.StepAndReset(); } + Status WriteScalar(int64 global_step, Tensor t, const string& tag) override { + Tensor t2; + TF_RETURN_IF_ERROR(CoerceScalar(t, &t2)); + // TODO(jart): Generate scalars plugin metadata on this value. + return WriteTensor(global_step, std::move(t2), tag, ""); + } + Status WriteGraph(int64 global_step, std::unique_ptr<GraphDef> g) override { mutex_lock ml(mu_); TF_RETURN_IF_ERROR(InitializeParents()); @@ -325,15 +377,6 @@ class SummaryDbWriter : public SummaryWriterInterface { } } - Status WriteScalar(int64 global_step, Tensor t, const string& tag) override { - // TODO(@jart): Unlike WriteTensor, this method would be granted leniency - // to change the dtype if it saves storage space. For example, - // DT_UINT32 would be stored in the database as an INTEGER - // rather than a serialized BLOB. But when reading it back, - // the dtype would become DT_INT64. - return errors::Unimplemented("WriteScalar"); - } - Status WriteHistogram(int64 global_step, Tensor t, const string& tag) override { return errors::Unimplemented( diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc index 3431842ca2..625861fa6b 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -290,5 +290,32 @@ TEST_F(SummaryDbWriterTest, WriteGraph) { EXPECT_EQ(1LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 2")); } +TEST_F(SummaryDbWriterTest, WriteScalarInt32_CoercesToInt64) { + TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_)); + Tensor t(DT_INT32, {}); + t.scalar<int32>()() = -17; + TF_ASSERT_OK(writer_->WriteScalar(1, t, "t")); + TF_ASSERT_OK(writer_->Flush()); + ASSERT_EQ(-17LL, QueryInt("SELECT tensor FROM Tensors")); +} + +TEST_F(SummaryDbWriterTest, WriteScalarInt8_CoercesToInt64) { + TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_)); + Tensor t(DT_INT8, {}); + t.scalar<int8>()() = static_cast<int8>(-17); + TF_ASSERT_OK(writer_->WriteScalar(1, t, "t")); + TF_ASSERT_OK(writer_->Flush()); + ASSERT_EQ(-17LL, QueryInt("SELECT tensor FROM Tensors")); +} + +TEST_F(SummaryDbWriterTest, WriteScalarUint8_CoercesToInt64) { + TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_)); + Tensor t(DT_UINT8, {}); + t.scalar<uint8>()() = static_cast<uint8>(254); + TF_ASSERT_OK(writer_->WriteScalar(1, t, "t")); + TF_ASSERT_OK(writer_->Flush()); + ASSERT_EQ(254LL, QueryInt("SELECT tensor FROM Tensors")); +} + } // namespace } // namespace tensorflow |