aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorboard
diff options
context:
space:
mode:
authorGravatar Justine Tunney <jart@google.com>2017-11-16 19:30:05 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-16 19:34:44 -0800
commit7d17d27940aa915583b0b3e2ba77d9f708af6783 (patch)
tree5022fc732328c2f9578876a3630348485bcbb666 /tensorflow/contrib/tensorboard
parent929178e1046f6387d9245c3d89ba5c3c1f3078d5 (diff)
Add WriteScalar support to SummaryDbWriter
PiperOrigin-RevId: 176058700
Diffstat (limited to 'tensorflow/contrib/tensorboard')
-rw-r--r--tensorflow/contrib/tensorboard/db/summary_db_writer.cc81
-rw-r--r--tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc27
2 files changed, 89 insertions, 19 deletions
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
index ae063d24ef..857e731ef2 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
@@ -81,6 +81,55 @@ Status BindTensor(SqliteStatement* stmt, int parameter, const Tensor& t) {
return BindProto(stmt, parameter, p);
}
+// Tries to fudge shape and dtype to something with smaller storage.
+Status CoerceScalar(const Tensor& t, Tensor* out) {
+ switch (t.dtype()) {
+ case DT_DOUBLE:
+ *out = t;
+ break;
+ case DT_INT64:
+ *out = t;
+ break;
+ case DT_FLOAT:
+ *out = {DT_DOUBLE, {}};
+ out->scalar<double>()() = t.scalar<float>()();
+ break;
+ case DT_HALF:
+ *out = {DT_DOUBLE, {}};
+ out->scalar<double>()() = static_cast<double>(t.scalar<Eigen::half>()());
+ break;
+ case DT_INT32:
+ *out = {DT_INT64, {}};
+ out->scalar<int64>()() = t.scalar<int32>()();
+ break;
+ case DT_INT16:
+ *out = {DT_INT64, {}};
+ out->scalar<int64>()() = t.scalar<int16>()();
+ break;
+ case DT_INT8:
+ *out = {DT_INT64, {}};
+ out->scalar<int64>()() = t.scalar<int8>()();
+ break;
+ case DT_UINT32:
+ *out = {DT_INT64, {}};
+ out->scalar<int64>()() = t.scalar<uint32>()();
+ break;
+ case DT_UINT16:
+ *out = {DT_INT64, {}};
+ out->scalar<int64>()() = t.scalar<uint16>()();
+ break;
+ case DT_UINT8:
+ *out = {DT_INT64, {}};
+ out->scalar<int64>()() = t.scalar<uint8>()();
+ break;
+ default:
+ return errors::Unimplemented("Scalar summary for dtype ",
+ DataTypeString(t.dtype()),
+ " is not supported.");
+ }
+ return Status::OK();
+}
+
class Transactor {
public:
explicit Transactor(std::shared_ptr<Sqlite> db)
@@ -280,20 +329,23 @@ class SummaryDbWriter : public SummaryWriterInterface {
insert_tensor_.BindInt(1, tag_id);
insert_tensor_.BindInt(2, global_step);
insert_tensor_.BindDouble(3, GetWallTime(env_));
- switch (t.dtype()) {
- case DT_INT64:
- insert_tensor_.BindInt(4, t.scalar<int64>()());
- break;
- case DT_DOUBLE:
- insert_tensor_.BindDouble(4, t.scalar<double>()());
- break;
- default:
- TF_RETURN_IF_ERROR(BindTensor(&insert_tensor_, 4, t));
- break;
+ if (t.shape().dims() == 0 && t.dtype() == DT_INT64) {
+ insert_tensor_.BindInt(4, t.scalar<int64>()());
+ } else if (t.shape().dims() == 0 && t.dtype() == DT_DOUBLE) {
+ insert_tensor_.BindDouble(4, t.scalar<double>()());
+ } else {
+ TF_RETURN_IF_ERROR(BindTensor(&insert_tensor_, 4, t));
}
return insert_tensor_.StepAndReset();
}
+ Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
+ Tensor t2;
+ TF_RETURN_IF_ERROR(CoerceScalar(t, &t2));
+ // TODO(jart): Generate scalars plugin metadata on this value.
+ return WriteTensor(global_step, std::move(t2), tag, "");
+ }
+
Status WriteGraph(int64 global_step, std::unique_ptr<GraphDef> g) override {
mutex_lock ml(mu_);
TF_RETURN_IF_ERROR(InitializeParents());
@@ -325,15 +377,6 @@ class SummaryDbWriter : public SummaryWriterInterface {
}
}
- Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
- // TODO(@jart): Unlike WriteTensor, this method would be granted leniency
- // to change the dtype if it saves storage space. For example,
- // DT_UINT32 would be stored in the database as an INTEGER
- // rather than a serialized BLOB. But when reading it back,
- // the dtype would become DT_INT64.
- return errors::Unimplemented("WriteScalar");
- }
-
Status WriteHistogram(int64 global_step, Tensor t,
const string& tag) override {
return errors::Unimplemented(
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
index 3431842ca2..625861fa6b 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
@@ -290,5 +290,32 @@ TEST_F(SummaryDbWriterTest, WriteGraph) {
EXPECT_EQ(1LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 2"));
}
+TEST_F(SummaryDbWriterTest, WriteScalarInt32_CoercesToInt64) {
+ TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
+ Tensor t(DT_INT32, {});
+ t.scalar<int32>()() = -17;
+ TF_ASSERT_OK(writer_->WriteScalar(1, t, "t"));
+ TF_ASSERT_OK(writer_->Flush());
+ ASSERT_EQ(-17LL, QueryInt("SELECT tensor FROM Tensors"));
+}
+
+TEST_F(SummaryDbWriterTest, WriteScalarInt8_CoercesToInt64) {
+ TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
+ Tensor t(DT_INT8, {});
+ t.scalar<int8>()() = static_cast<int8>(-17);
+ TF_ASSERT_OK(writer_->WriteScalar(1, t, "t"));
+ TF_ASSERT_OK(writer_->Flush());
+ ASSERT_EQ(-17LL, QueryInt("SELECT tensor FROM Tensors"));
+}
+
+TEST_F(SummaryDbWriterTest, WriteScalarUint8_CoercesToInt64) {
+ TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
+ Tensor t(DT_UINT8, {});
+ t.scalar<uint8>()() = static_cast<uint8>(254);
+ TF_ASSERT_OK(writer_->WriteScalar(1, t, "t"));
+ TF_ASSERT_OK(writer_->Flush());
+ ASSERT_EQ(254LL, QueryInt("SELECT tensor FROM Tensors"));
+}
+
} // namespace
} // namespace tensorflow