/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>

#include "tensorflow/contrib/tensorboard/db/schema.h"
#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
#include "tensorflow/contrib/tensorboard/db/summary_file_writer.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/db/sqlite.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/event.pb.h"

namespace tensorflow {

// Kernel producing the resource handle that every op below receives as
// input 0.
REGISTER_KERNEL_BUILDER(Name("SummaryWriter").Device(DEVICE_CPU),
                        ResourceHandleOp<SummaryWriterInterface>);

// Binds the handle at input 0 to a SummaryWriterInterface built by
// CreateSummaryFileWriter from the "logdir", "max_queue", "flush_millis" and
// "filename_suffix" inputs.
class CreateSummaryFileWriterOp : public OpKernel {
 public:
  explicit CreateSummaryFileWriterOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor* tmp;
    OP_REQUIRES_OK(ctx, ctx->input("logdir", &tmp));
    const string logdir = tmp->scalar<string>()();
    OP_REQUIRES_OK(ctx, ctx->input("max_queue", &tmp));
    const int32 max_queue = tmp->scalar<int32>()();
    OP_REQUIRES_OK(ctx, ctx->input("flush_millis", &tmp));
    const int32 flush_millis = tmp->scalar<int32>()();
    OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp));
    const string filename_suffix = tmp->scalar<string>()();

    SummaryWriterInterface* s = nullptr;
    OP_REQUIRES_OK(ctx, LookupOrCreateResource<SummaryWriterInterface>(
                            ctx, HandleFromInput(ctx, 0), &s,
                            [max_queue, flush_millis, logdir, filename_suffix,
                             ctx](SummaryWriterInterface** writer) {
                              return CreateSummaryFileWriter(
                                  max_queue, flush_millis, logdir,
                                  filename_suffix, ctx->env(), writer);
                            }));
    // The lookup hands this kernel its own reference on *s; drop it on scope
    // exit, as every other kernel in this file does, so it is not leaked.
    core::ScopedUnref unref(s);
  }
};
REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
                        CreateSummaryFileWriterOp);

// Binds the handle at input 0 to a SummaryWriterInterface built by
// CreateSummaryDbWriter on the SQLite database named by the "db_uri" input,
// tagged with the "experiment_name", "run_name" and "user_name" inputs.
class CreateSummaryDbWriterOp : public OpKernel {
 public:
  explicit CreateSummaryDbWriterOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor* tmp;
    OP_REQUIRES_OK(ctx, ctx->input("db_uri", &tmp));
    const string db_uri = tmp->scalar<string>()();
    OP_REQUIRES_OK(ctx, ctx->input("experiment_name", &tmp));
    const string experiment_name = tmp->scalar<string>()();
    OP_REQUIRES_OK(ctx, ctx->input("run_name", &tmp));
    const string run_name = tmp->scalar<string>()();
    OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp));
    const string user_name = tmp->scalar<string>()();

    SummaryWriterInterface* s = nullptr;
    OP_REQUIRES_OK(
        ctx,
        LookupOrCreateResource<SummaryWriterInterface>(
            ctx, HandleFromInput(ctx, 0), &s,
            [db_uri, experiment_name, run_name, user_name,
             ctx](SummaryWriterInterface** writer) {
              Sqlite* db;
              TF_RETURN_IF_ERROR(Sqlite::Open(
                  db_uri, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, &db));
              // Release the local reference on the connection when the
              // creator returns, on success and on every error path.
              core::ScopedUnref unref(db);
              TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db));
              TF_RETURN_IF_ERROR(CreateSummaryDbWriter(
                  db, experiment_name, run_name, user_name, ctx->env(),
                  writer));
              return Status::OK();
            }));
    // The lookup hands this kernel its own reference on *s; drop it on scope
    // exit, as every other kernel in this file does, so it is not leaked.
    core::ScopedUnref unref(s);
  }
};
REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU),
                        CreateSummaryDbWriterOp);

// Calls Flush() on the writer behind the handle at input 0.
class FlushSummaryWriterOp : public OpKernel {
 public:
  explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    SummaryWriterInterface* s;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    core::ScopedUnref unref(s);
    OP_REQUIRES_OK(ctx, s->Flush());
  }
};
REGISTER_KERNEL_BUILDER(Name("FlushSummaryWriter").Device(DEVICE_CPU),
                        FlushSummaryWriterOp);

// Deletes the SummaryWriterInterface resource behind the handle at input 0.
class CloseSummaryWriterOp : public OpKernel {
 public:
  explicit CloseSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES_OK(ctx, DeleteResource<SummaryWriterInterface>(
                            ctx, HandleFromInput(ctx, 0)));
  }
};
REGISTER_KERNEL_BUILDER(Name("CloseSummaryWriter").Device(DEVICE_CPU),
                        CloseSummaryWriterOp);

// Forwards the "tensor" input to WriteTensor() together with the "step",
// "tag" and "summary_metadata" inputs.
class WriteSummaryOp : public OpKernel {
 public:
  explicit WriteSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    SummaryWriterInterface* s;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    core::ScopedUnref unref(s);
    const Tensor* tmp;
    OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
    const int64 step = tmp->scalar<int64>()();
    OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
    const string& tag = tmp->scalar<string>()();
    OP_REQUIRES_OK(ctx, ctx->input("summary_metadata", &tmp));
    const string& serialized_metadata = tmp->scalar<string>()();

    const Tensor* t;
    OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));

    OP_REQUIRES_OK(ctx, s->WriteTensor(step, *t, tag, serialized_metadata));
  }
};
REGISTER_KERNEL_BUILDER(Name("WriteSummary").Device(DEVICE_CPU),
                        WriteSummaryOp);

// Parses the "event" input as a serialized Event proto and hands it to
// WriteEvent(). A string that does not parse fails the op with DataLoss.
class ImportEventOp : public OpKernel {
 public:
  explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    SummaryWriterInterface* s;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    core::ScopedUnref unref(s);
    const Tensor* t;
    OP_REQUIRES_OK(ctx, ctx->input("event", &t));
    std::unique_ptr<Event> event{new Event};
    if (!ParseProtoUnlimited(event.get(), t->scalar<string>()())) {
      ctx->CtxFailureWithWarning(
          errors::DataLoss("Bad tf.Event binary proto tensor string"));
      return;
    }
    OP_REQUIRES_OK(ctx, s->WriteEvent(std::move(event)));
  }
};
REGISTER_KERNEL_BUILDER(Name("ImportEvent").Device(DEVICE_CPU), ImportEventOp);

// Forwards the "value" input to WriteScalar() with the "step" and "tag"
// inputs.
class WriteScalarSummaryOp : public OpKernel {
 public:
  explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    SummaryWriterInterface* s;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    core::ScopedUnref unref(s);
    const Tensor* tmp;
    OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
    const int64 step = tmp->scalar<int64>()();
    OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
    const string& tag = tmp->scalar<string>()();

    const Tensor* t;
    OP_REQUIRES_OK(ctx, ctx->input("value", &t));

    OP_REQUIRES_OK(ctx, s->WriteScalar(step, *t, tag));
  }
};
REGISTER_KERNEL_BUILDER(Name("WriteScalarSummary").Device(DEVICE_CPU),
                        WriteScalarSummaryOp);

// Forwards the "values" input to WriteHistogram() with the "step" and "tag"
// inputs.
class WriteHistogramSummaryOp : public OpKernel {
 public:
  explicit WriteHistogramSummaryOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    SummaryWriterInterface* s;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    core::ScopedUnref unref(s);
    const Tensor* tmp;
    OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
    const int64 step = tmp->scalar<int64>()();
    OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
    const string& tag = tmp->scalar<string>()();

    const Tensor* t;
    OP_REQUIRES_OK(ctx, ctx->input("values", &t));

    OP_REQUIRES_OK(ctx, s->WriteHistogram(step, *t, tag));
  }
};
REGISTER_KERNEL_BUILDER(Name("WriteHistogramSummary").Device(DEVICE_CPU),
                        WriteHistogramSummaryOp);

// Forwards the "tensor" input to WriteImage() with the "step", "tag" and
// "bad_color" inputs and the "max_images" attr.
class WriteImageSummaryOp : public OpKernel {
 public:
  explicit WriteImageSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    // The attr is read as int64 and must be below 2^31 before it is narrowed
    // to the int32 member.
    int64 max_images_tmp;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max_images", &max_images_tmp));
    OP_REQUIRES(ctx, max_images_tmp < (1LL << 31),
                errors::InvalidArgument("max_images must be < 2^31"));
    max_images_ = static_cast<int32>(max_images_tmp);
  }

  void Compute(OpKernelContext* ctx) override {
    SummaryWriterInterface* s;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    core::ScopedUnref unref(s);
    const Tensor* tmp;
    OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
    const int64 step = tmp->scalar<int64>()();
    OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
    const string& tag = tmp->scalar<string>()();
    const Tensor* bad_color;
    OP_REQUIRES_OK(ctx, ctx->input("bad_color", &bad_color));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(bad_color->shape()),
        errors::InvalidArgument("bad_color must be a vector, got shape ",
                                bad_color->shape().DebugString()));

    const Tensor* t;
    OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));

    OP_REQUIRES_OK(ctx, s->WriteImage(step, *t, tag, max_images_, *bad_color));
  }

 private:
  int32 max_images_;
};
REGISTER_KERNEL_BUILDER(Name("WriteImageSummary").Device(DEVICE_CPU),
                        WriteImageSummaryOp);

// Forwards the "tensor" input to WriteAudio() with the "step", "tag" and
// "sample_rate" inputs and the "max_outputs" attr, which must be positive.
class WriteAudioSummaryOp : public OpKernel {
 public:
  explicit WriteAudioSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max_outputs", &max_outputs_));
    OP_REQUIRES(ctx, max_outputs_ > 0,
                errors::InvalidArgument("max_outputs must be > 0"));
  }

  void Compute(OpKernelContext* ctx) override {
    SummaryWriterInterface* s;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    core::ScopedUnref unref(s);
    const Tensor* tmp;
    OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
    const int64 step = tmp->scalar<int64>()();
    OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
    const string& tag = tmp->scalar<string>()();
    OP_REQUIRES_OK(ctx, ctx->input("sample_rate", &tmp));
    const float sample_rate = tmp->scalar<float>()();

    const Tensor* t;
    OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));

    OP_REQUIRES_OK(ctx,
                   s->WriteAudio(step, *t, tag, max_outputs_, sample_rate));
  }

 private:
  int max_outputs_;
};
REGISTER_KERNEL_BUILDER(Name("WriteAudioSummary").Device(DEVICE_CPU),
                        WriteAudioSummaryOp);

// Parses the "tensor" input as a serialized GraphDef proto and hands it to
// WriteGraph() with the "step" input. A string that does not parse fails the
// op with DataLoss.
class WriteGraphSummaryOp : public OpKernel {
 public:
  explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    SummaryWriterInterface* s;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
    core::ScopedUnref unref(s);
    const Tensor* t;
    OP_REQUIRES_OK(ctx, ctx->input("step", &t));
    const int64 step = t->scalar<int64>()();
    OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
    std::unique_ptr<GraphDef> graph{new GraphDef};
    if (!ParseProtoUnlimited(graph.get(), t->scalar<string>()())) {
      ctx->CtxFailureWithWarning(
          errors::DataLoss("Bad tf.GraphDef binary proto tensor string"));
      return;
    }
    OP_REQUIRES_OK(ctx, s->WriteGraph(step, std::move(graph)));
  }
};
REGISTER_KERNEL_BUILDER(Name("WriteGraphSummary").Device(DEVICE_CPU),
                        WriteGraphSummaryOp);

}  // namespace tensorflow