aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/encode_png_op.cc
blob: e01ec9afe90e7e60b90c760db916abb1d51701a9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
// See docs in ../ops/image_ops.cc

#include <limits>
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/png/png_io.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/public/tensor_shape.h"

namespace tensorflow {

// Encode an image to a PNG stream
class EncodePngOp : public OpKernel {
 public:
  explicit EncodePngOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("compression", &compression_));
    OP_REQUIRES(context, -1 <= compression_ && compression_ <= 9,
                errors::InvalidArgument("compression should be in [-1,9], got ",
                                        compression_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& image = context->input(0);
    OP_REQUIRES(context, image.dims() == 3,
                errors::InvalidArgument("image must be 3-dimensional",
                                        image.shape().ShortDebugString()));
    const int64 channels = image.dim_size(2);
    OP_REQUIRES(context, channels == 1 || channels == 3 || channels == 4,
                errors::InvalidArgument(
                    "image must have 1, 3, or 4 channels, got ", channels));

    // Encode image to png string
    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output));
    OP_REQUIRES(context,
                png::WriteImageToBuffer(
                    image.flat<uint8>().data(), image.dim_size(1),
                    image.dim_size(0), image.dim_size(1) * channels, channels,
                    8, compression_, &output->scalar<string>()(), nullptr),
                errors::Internal("PNG encoding failed"));
  }

 private:
  int compression_;
};
// Registers EncodePngOp as the CPU kernel implementation of the "EncodePng" op.
REGISTER_KERNEL_BUILDER(Name("EncodePng").Device(DEVICE_CPU), EncodePngOp);

}  // namespace tensorflow