aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image/kernels/image_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/image/kernels/image_ops.cc')
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.cc24
1 files changed, 20 insertions, 4 deletions
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index 8d50541771..8a97f07732 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -43,13 +43,29 @@ template class FillProjectiveTransform<CPUDevice, double>;
typedef Eigen::ThreadPoolDevice CPUDevice;
using functor::FillProjectiveTransform;
+using generator::INTERPOLATION_BILINEAR;
+using generator::INTERPOLATION_NEAREST;
+using generator::Interpolation;
using generator::ProjectiveGenerator;
template <typename Device, typename T>
class ImageProjectiveTransform : public OpKernel {
+ private:
+ Interpolation interpolation_;
+
public:
- explicit ImageProjectiveTransform(OpKernelConstruction* ctx)
- : OpKernel(ctx) {}
+ explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string interpolation_str;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
+ if (interpolation_str == "NEAREST") {
+ interpolation_ = INTERPOLATION_NEAREST;
+ } else if (interpolation_str == "BILINEAR") {
+ interpolation_ = INTERPOLATION_BILINEAR;
+ } else {
+ LOG(FATAL) << "Invalid interpolation " << interpolation_str
+ << ". Supported types: NEAREST, BILINEAR";
+ }
+ }
void Compute(OpKernelContext* ctx) override {
const Tensor& images_t = ctx->input(0);
@@ -68,8 +84,8 @@ class ImageProjectiveTransform : public OpKernel {
Tensor* output_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
auto output = output_t->tensor<T, 4>();
- const FillProjectiveTransform<Device, T> functor;
- functor(ctx->eigen_device<Device>(), &output, images, transform);
+ (FillProjectiveTransform<Device, T>(interpolation_))(
+ ctx->eigen_device<Device>(), &output, images, transform);
}
};