aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/summary_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/summary_op.cc')
-rw-r--r--tensorflow/core/kernels/summary_op.cc141
1 files changed, 141 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/summary_op.cc b/tensorflow/core/kernels/summary_op.cc
new file mode 100644
index 0000000000..1c4be64b8b
--- /dev/null
+++ b/tensorflow/core/kernels/summary_op.cc
@@ -0,0 +1,141 @@
+// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
+// inputs or outputs in various ways.
+
+// See docs in ../ops/summary_ops.cc.
+
+#include <unordered_set>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/histogram/histogram.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+template <typename T>
+class SummaryScalarOp : public OpKernel {
+ public:
+ explicit SummaryScalarOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* c) override {
+ const Tensor& tags = c->input(0);
+ const Tensor& values = c->input(1);
+
+ OP_REQUIRES(c, tags.IsSameSize(values) ||
+ (TensorShapeUtils::IsLegacyScalar(tags.shape()) &&
+ TensorShapeUtils::IsLegacyScalar(values.shape())),
+ errors::InvalidArgument("tags and values not the same shape: ",
+ tags.shape().ShortDebugString(), " != ",
+ values.shape().ShortDebugString()));
+ auto Ttags = tags.flat<string>();
+ auto Tvalues = values.flat<T>();
+ Summary s;
+ for (int i = 0; i < Ttags.size(); i++) {
+ Summary::Value* v = s.add_value();
+ v->set_tag(Ttags(i));
+ v->set_simple_value(Tvalues(i));
+ }
+
+ Tensor* summary_tensor = nullptr;
+ OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
+ CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ScalarSummary")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ SummaryScalarOp<float>);
+REGISTER_KERNEL_BUILDER(Name("ScalarSummary")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<double>("T"),
+ SummaryScalarOp<double>);
+
+class SummaryHistoOp : public OpKernel {
+ public:
+ // SummaryHistoOp could be extended to take a list of custom bucket
+ // boundaries as an option.
+ explicit SummaryHistoOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* c) override {
+ const Tensor& tags = c->input(0);
+ const Tensor& values = c->input(1);
+ const auto flat = values.flat<float>();
+ OP_REQUIRES(c, TensorShapeUtils::IsLegacyScalar(tags.shape()),
+ errors::InvalidArgument("tags must be scalar"));
+ // Build histogram of values in "values" tensor
+ histogram::Histogram histo;
+ for (int64 i = 0; i < flat.size(); i++) {
+ float v = flat(i);
+ if (!std::isfinite(v)) {
+ c->SetStatus(
+ errors::OutOfRange("Nan in summary histogram for: ", name()));
+ break;
+ }
+ histo.Add(v);
+ }
+
+ Summary s;
+ Summary::Value* v = s.add_value();
+ v->set_tag(tags.scalar<string>()());
+ histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
+
+ Tensor* summary_tensor = nullptr;
+ OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
+ CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("HistogramSummary").Device(DEVICE_CPU),
+ SummaryHistoOp);
+
+struct HistogramResource : public ResourceBase {
+ histogram::ThreadSafeHistogram histogram;
+
+ string DebugString() override { return "A historam summary. Stats ..."; }
+};
+
+class SummaryMergeOp : public OpKernel {
+ public:
+ explicit SummaryMergeOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* c) override {
+ Summary s;
+ std::unordered_set<string> tags;
+ for (int input_num = 0; input_num < c->num_inputs(); input_num++) {
+ const Tensor& in = c->input(input_num);
+ auto in_vec = in.flat<string>();
+ for (int i = 0; i < in_vec.dimension(0); i++) {
+ const string& s_in = in_vec(i);
+ Summary summary_in;
+ if (!ParseProtoUnlimited(&summary_in, s_in)) {
+ c->SetStatus(errors::InvalidArgument(
+ "Could not parse one of the summary inputs"));
+ return;
+ }
+
+ for (int v = 0; v < summary_in.value_size(); v++) {
+ if (!tags.insert(summary_in.value(v).tag()).second) {
+ c->SetStatus(errors::InvalidArgument(
+ strings::StrCat("Duplicate tag ", summary_in.value(v).tag(),
+ " found in summary inputs")));
+ return;
+ }
+ *s.add_value() = summary_in.value(v);
+ }
+ }
+ }
+
+ Tensor* summary_tensor = nullptr;
+ OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
+ CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MergeSummary").Device(DEVICE_CPU),
+ SummaryMergeOp);
+
+} // namespace tensorflow