diff options
Diffstat (limited to 'tensorflow/core/kernels/summary_image_op.cc')
-rw-r--r-- | tensorflow/core/kernels/summary_image_op.cc | 169 |
1 files changed, 169 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/summary_image_op.cc b/tensorflow/core/kernels/summary_image_op.cc new file mode 100644 index 0000000000..ba765f2e84 --- /dev/null +++ b/tensorflow/core/kernels/summary_image_op.cc @@ -0,0 +1,169 @@ +// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as +// inputs or outputs in various ways. + +// See docs in ../ops/summary_ops.cc. + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/png/png_io.h" + +namespace tensorflow { + +class SummaryImageOp : public OpKernel { + public: + explicit SummaryImageOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("max_images", &max_images_)); + const TensorProto* proto; + OP_REQUIRES_OK(context, context->GetAttr("bad_color", &proto)); + OP_REQUIRES_OK(context, context->device()->MakeTensorFromProto( + *proto, AllocatorAttributes(), &bad_color_)); + OP_REQUIRES(context, bad_color_.dtype() == DT_UINT8, + errors::InvalidArgument("bad_color must be uint8, got ", + DataTypeString(bad_color_.dtype()))); + OP_REQUIRES( + context, TensorShapeUtils::IsVector(bad_color_.shape()), + errors::InvalidArgument("bad_color must be a vector, got shape ", + bad_color_.shape().ShortDebugString())); + } + + void Compute(OpKernelContext* c) override { + const Tensor& tags = c->input(0); + const Tensor& tensor = c->input(1); + OP_REQUIRES(c, TensorShapeUtils::IsLegacyScalar(tags.shape()), + errors::InvalidArgument("Tags must have be a scalar")); + OP_REQUIRES(c, tensor.dims() == 4 && + (tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 || + tensor.dim_size(3) == 4), + errors::InvalidArgument( + "Tensor must be 4-D with last dim 1, 3, or 4, not ", + tensor.shape().DebugString())); + const string& base_tag = tags.scalar<string>()(); + + const int batch_size = tensor.dim_size(0); + const int h = tensor.dim_size(1); + const int w = tensor.dim_size(2); + const int hw = h * w; // Compact these two dims for simplicity + const int depth = tensor.dim_size(3); + auto tensor_eigen = tensor.shaped<float, 3>({batch_size, hw, depth}); + + OP_REQUIRES(c, bad_color_.dim_size(0) >= depth, + errors::InvalidArgument( + "expected depth <= bad_color.size, got depth = ", depth, + ", bad_color.size = ", bad_color_.dim_size(0))); + auto bad_color_full = bad_color_.vec<uint8>(); + typename TTypes<uint8>::Vec bad_color(bad_color_full.data(), depth); + + // RGB (or gray or RGBA) is last dimension + Eigen::Tensor<uint8, 2, Eigen::RowMajor> image(hw, depth); + + Summary s; + const int N = std::min<int>(max_images_, batch_size); + for (int i = 0; i < N; ++i) { + Summary::Value* v = s.add_value(); + // The tag depends on the number of requested images (not the number + // produced.) + // + // Note that later on avisu uses "/" to figure out a consistent naming + // convention for display, so we append "/image" to guarantee that the + // image(s) won't be displayed in the global scope with no name. + if (max_images_ > 1) { + v->set_tag(strings::StrCat(base_tag, "/image/", i)); + } else { + v->set_tag(strings::StrCat(base_tag, "/image")); + } + + if (image.size()) { + typename TTypes<float>::ConstMatrix values( + &tensor_eigen(i, 0, 0), + Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth)); + + // Rescale the image to uint8 range. + // + // We are trying to generate an RCG image from a float tensor. We do + // not have any info about the expected range of values in the tensor + // but the generated image needs to have all RGB values within [0, 255]. + // + // We use two different algorithms to generate these values. If the + // tensor has only positive values we scale them all by 255/max(values). + // If the tensor has both negative and positive values we scale them by + // the max of their absolute values and center them around 127. + // + // This works for most cases, but has the incovenient of not respecting + // the relative dynamic range across different instances of the tensor. + + // Compute min and max ignoring nonfinite pixels + float image_min = std::numeric_limits<float>::infinity(); + float image_max = -image_min; + for (int i = 0; i < hw; i++) { + bool finite = true; + for (int j = 0; j < depth; j++) { + if (!std::isfinite(values(i, j))) { + finite = false; + break; + } + } + if (finite) { + for (int j = 0; j < depth; j++) { + float value = values(i, j); + image_min = std::min(image_min, value); + image_max = std::max(image_max, value); + } + } + } + + // Pick an affine transform into uint8 + const float kZeroThreshold = 1e-6; + float scale, offset; + if (image_min < 0) { + float max_val = std::max(std::abs(image_min), std::abs(image_max)); + scale = max_val < kZeroThreshold ? 0.0f : 127.0f / max_val; + offset = 128.0f; + } else { + scale = image_max < kZeroThreshold ? 0.0f : 255.0f / image_max; + offset = 0.0f; + } + + // Transform image, turning nonfinite values to bad_color + for (int i = 0; i < hw; i++) { + bool finite = true; + for (int j = 0; j < depth; j++) { + if (!std::isfinite(values(i, j))) { + finite = false; + break; + } + } + if (finite) { + image.chip<0>(i) = + (values.chip<0>(i) * scale + offset).cast<uint8>(); + } else { + image.chip<0>(i) = bad_color; + } + } + } + + Summary::Image* si = v->mutable_image(); + si->set_height(h); + si->set_width(w); + si->set_colorspace(depth); + OP_REQUIRES(c, png::WriteImageToBuffer( + image.data(), w, h, w * depth, depth, 8, -1, + si->mutable_encoded_image_string(), nullptr), + errors::Internal("PNG encoding failed")); + } + + Tensor* summary_tensor = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor)); + CHECK(s.SerializeToString(&summary_tensor->scalar<string>()())); + } + + private: + int64 max_images_; + Tensor bad_color_; +}; + +REGISTER_KERNEL_BUILDER(Name("ImageSummary").Device(DEVICE_CPU), + SummaryImageOp); + +} // namespace tensorflow |