diff options
Diffstat (limited to 'tensorflow/contrib/image/kernels/image_ops.cc')
-rw-r--r-- | tensorflow/contrib/image/kernels/image_ops.cc | 24 |
1 files changed, 20 insertions, 4 deletions
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc index 8d50541771..8a97f07732 100644 --- a/tensorflow/contrib/image/kernels/image_ops.cc +++ b/tensorflow/contrib/image/kernels/image_ops.cc @@ -43,13 +43,29 @@ template class FillProjectiveTransform<CPUDevice, double>; typedef Eigen::ThreadPoolDevice CPUDevice; using functor::FillProjectiveTransform; +using generator::INTERPOLATION_BILINEAR; +using generator::INTERPOLATION_NEAREST; +using generator::Interpolation; using generator::ProjectiveGenerator; template <typename Device, typename T> class ImageProjectiveTransform : public OpKernel { + private: + Interpolation interpolation_; + public: - explicit ImageProjectiveTransform(OpKernelConstruction* ctx) - : OpKernel(ctx) {} + explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) { + string interpolation_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str)); + if (interpolation_str == "NEAREST") { + interpolation_ = INTERPOLATION_NEAREST; + } else if (interpolation_str == "BILINEAR") { + interpolation_ = INTERPOLATION_BILINEAR; + } else { + LOG(FATAL) << "Invalid interpolation " << interpolation_str + << ". Supported types: NEAREST, BILINEAR"; + } + } void Compute(OpKernelContext* ctx) override { const Tensor& images_t = ctx->input(0); @@ -68,8 +84,8 @@ class ImageProjectiveTransform : public OpKernel { Tensor* output_t; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t)); auto output = output_t->tensor<T, 4>(); - const FillProjectiveTransform<Device, T> functor; - functor(ctx->eigen_device<Device>(), &output, images, transform); + (FillProjectiveTransform<Device, T>(interpolation_))( + ctx->eigen_device<Device>(), &output, images, transform); } }; |