aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/crop_and_resize_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/crop_and_resize_op.cc')
-rw-r--r-- tensorflow/core/kernels/crop_and_resize_op.cc | 151
1 files changed, 95 insertions, 56 deletions
diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc
index 54ef9c6fb4..99d01b4db6 100644
--- a/tensorflow/core/kernels/crop_and_resize_op.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op.cc
@@ -110,10 +110,10 @@ class CropAndResizeOp : public AsyncOpKernel {
public:
explicit CropAndResizeOp(OpKernelConstruction* context)
: AsyncOpKernel(context) {
- string method;
- OP_REQUIRES_OK(context, context->GetAttr("method", &method));
- OP_REQUIRES(context, method == "bilinear",
- errors::InvalidArgument("method must be 'bilinear'", method));
+ OP_REQUIRES_OK(context, context->GetAttr("method", &method_));
+ OP_REQUIRES(context, method_ == "bilinear" || method_ == "nearest",
+ errors::InvalidArgument(  // message pieces are concatenated as-is
+ "method must be 'bilinear' or 'nearest', got: ", method_));
OP_REQUIRES_OK(context, context->GetAttr("extrapolation_value",
&extrapolation_value_));
}
@@ -178,7 +178,7 @@ class CropAndResizeOp : public AsyncOpKernel {
const Tensor& box_index = context->input(2);
const bool status = functor::CropAndResize<Device, T>()(
context, image.tensor<T, 4>(), boxes.tensor<float, 2>(),
- box_index.tensor<int32, 1>(), extrapolation_value_,
+ box_index.tensor<int32, 1>(), method_, extrapolation_value_,
output->tensor<float, 4>());
if (!status) {
context->SetStatus(
@@ -193,6 +193,7 @@ class CropAndResizeOp : public AsyncOpKernel {
private:
float extrapolation_value_;
+ string method_;
};
// Partial specialization of CropAndResize functor for a CPUDevice.
@@ -203,7 +204,7 @@ struct CropAndResize<CPUDevice, T> {
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_index,
- float extrapolation_value,
+ const string& method_name, float extrapolation_value,
typename TTypes<float, 4>::Tensor crops) {
const int batch_size = image.dimension(0);
const int image_height = image.dimension(1);
@@ -247,37 +248,57 @@ struct CropAndResize<CPUDevice, T> {
}
continue;
}
- const int top_y_index = floorf(in_y);
- const int bottom_y_index = ceilf(in_y);
- const float y_lerp = in_y - top_y_index;
-
- for (int x = 0; x < crop_width; ++x) {
- const float in_x = (crop_width > 1)
- ? x1 * (image_width - 1) + x * width_scale
- : 0.5 * (x1 + x2) * (image_width - 1);
- if (in_x < 0 || in_x > image_width - 1) {
+ if (method_name == "bilinear") {
+ const int top_y_index = floorf(in_y);
+ const int bottom_y_index = ceilf(in_y);
+ const float y_lerp = in_y - top_y_index;
+
+ for (int x = 0; x < crop_width; ++x) {
+ const float in_x = (crop_width > 1)
+ ? x1 * (image_width - 1) + x * width_scale
+ : 0.5 * (x1 + x2) * (image_width - 1);
+ if (in_x < 0 || in_x > image_width - 1) {
+ for (int d = 0; d < depth; ++d) {
+ crops(b, y, x, d) = extrapolation_value;
+ }
+ continue;
+ }
+ const int left_x_index = floorf(in_x);
+ const int right_x_index = ceilf(in_x);
+ const float x_lerp = in_x - left_x_index;
+
for (int d = 0; d < depth; ++d) {
- crops(b, y, x, d) = extrapolation_value;
+ const float top_left(static_cast<float>(
+ image(b_in, top_y_index, left_x_index, d)));
+ const float top_right(static_cast<float>(
+ image(b_in, top_y_index, right_x_index, d)));
+ const float bottom_left(static_cast<float>(
+ image(b_in, bottom_y_index, left_x_index, d)));
+ const float bottom_right(static_cast<float>(
+ image(b_in, bottom_y_index, right_x_index, d)));
+ const float top = top_left + (top_right - top_left) * x_lerp;
+ const float bottom =
+ bottom_left + (bottom_right - bottom_left) * x_lerp;
+ crops(b, y, x, d) = top + (bottom - top) * y_lerp;
}
- continue;
}
- const int left_x_index = floorf(in_x);
- const int right_x_index = ceilf(in_x);
- const float x_lerp = in_x - left_x_index;
-
- for (int d = 0; d < depth; ++d) {
- const float top_left(static_cast<float>(
- image(b_in, top_y_index, left_x_index, d)));
- const float top_right(static_cast<float>(
- image(b_in, top_y_index, right_x_index, d)));
- const float bottom_left(static_cast<float>(
- image(b_in, bottom_y_index, left_x_index, d)));
- const float bottom_right(static_cast<float>(
- image(b_in, bottom_y_index, right_x_index, d)));
- const float top = top_left + (top_right - top_left) * x_lerp;
- const float bottom =
- bottom_left + (bottom_right - bottom_left) * x_lerp;
- crops(b, y, x, d) = top + (bottom - top) * y_lerp;
+ } else { // method_name == "nearest"
+ for (int x = 0; x < crop_width; ++x) {
+ const float in_x = (crop_width > 1)
+ ? x1 * (image_width - 1) + x * width_scale
+ : 0.5 * (x1 + x2) * (image_width - 1);
+ if (in_x < 0 || in_x > image_width - 1) {
+ for (int d = 0; d < depth; ++d) {
+ crops(b, y, x, d) = extrapolation_value;
+ }
+ continue;
+ }
+ const int closest_x_index = roundf(in_x);
+ const int closest_y_index = roundf(in_y);
+ for (int d = 0; d < depth; ++d) {
+ crops(b, y, x, d) = static_cast<float>(
+ image(b_in, closest_y_index, closest_x_index, d));
+ }
}
}
}
@@ -285,12 +306,17 @@ struct CropAndResize<CPUDevice, T> {
};
// A rough estimation of the cost for each cropped box.
- const double cost_per_pixel =
+ double cost_per_pixel =
depth * (Eigen::TensorOpCost::AddCost<float>() * 6 +
Eigen::TensorOpCost::MulCost<float>() * 3 +
Eigen::TensorOpCost::CastCost<T, float>() * 4) +
(Eigen::TensorOpCost::AddCost<float>() * 2 +
Eigen::TensorOpCost::AddCost<float>() * 3);
+ if (method_name == "nearest") {
+ cost_per_pixel = depth * Eigen::TensorOpCost::CastCost<T, float>() +
+ Eigen::TensorOpCost::AddCost<float>() * 4 +
+ Eigen::TensorOpCost::MulCost<float>() * 4;
+ }
const double cost_per_box = crop_height * crop_width * cost_per_pixel;
const DeviceBase::CpuWorkerThreads& worker_threads =
@@ -309,10 +335,10 @@ class CropAndResizeGradImageOp : public AsyncOpKernel {
public:
explicit CropAndResizeGradImageOp(OpKernelConstruction* context)
: AsyncOpKernel(context) {
- string method;
- OP_REQUIRES_OK(context, context->GetAttr("method", &method));
- OP_REQUIRES(context, method == "bilinear",
- errors::InvalidArgument("method must be 'bilinear'", method));
+ OP_REQUIRES_OK(context, context->GetAttr("method", &method_));
+ OP_REQUIRES(context, method_ == "bilinear" || method_ == "nearest",
+ errors::InvalidArgument(  // message pieces are concatenated as-is
+ "method must be 'bilinear' or 'nearest', got: ", method_));
}
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
@@ -372,14 +398,14 @@ class CropAndResizeGradImageOp : public AsyncOpKernel {
&output),
done);
- auto compute_callback = [context, output]() {
+ auto compute_callback = [this, context, output]() {
const Tensor& grads = context->input(0);
const Tensor& boxes = context->input(1);
const Tensor& box_index = context->input(2);
const bool status = functor::CropAndResizeBackpropImage<Device, T>()(
context->eigen_device<Device>(), grads.tensor<float, 4>(),
boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
- output->tensor<T, 4>());
+ output->tensor<T, 4>(), method_);
if (!status) {
context->SetStatus(errors::Internal(
"Failed launch CropAndResizeBackpropImage kernel."));
@@ -390,6 +416,9 @@ class CropAndResizeGradImageOp : public AsyncOpKernel {
batch_size, std::move(compute_callback),
std::move(done));
}
+
+ private:
+ string method_;
};
// Partial specialization of CropAndResizeBackpropImage functor for a CPUDevice.
@@ -400,7 +429,8 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_index,
- typename TTypes<T, 4>::Tensor grads_image) {
+ typename TTypes<T, 4>::Tensor grads_image,
+ const string& method_name) {
const int batch_size = grads_image.dimension(0);
const int image_height = grads_image.dimension(1);
const int image_width = grads_image.dimension(2);
@@ -448,21 +478,30 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
if (in_x < 0 || in_x > image_width - 1) {
continue;
}
- const int left_x_index = floorf(in_x);
- const int right_x_index = ceilf(in_x);
- const float x_lerp = in_x - left_x_index;
+ if (method_name == "bilinear") {
+ const int left_x_index = floorf(in_x);
+ const int right_x_index = ceilf(in_x);
+ const float x_lerp = in_x - left_x_index;
- for (int d = 0; d < depth; ++d) {
- const float dtop = (1 - y_lerp) * grads(b, y, x, d);
- grads_image(b_in, top_y_index, left_x_index, d) +=
- static_cast<T>((1 - x_lerp) * dtop);
- grads_image(b_in, top_y_index, right_x_index, d) +=
- static_cast<T>(x_lerp * dtop);
- const float dbottom = y_lerp * grads(b, y, x, d);
- grads_image(b_in, bottom_y_index, left_x_index, d) +=
- static_cast<T>((1 - x_lerp) * dbottom);
- grads_image(b_in, bottom_y_index, right_x_index, d) +=
- static_cast<T>(x_lerp * dbottom);
+ for (int d = 0; d < depth; ++d) {
+ const float dtop = (1 - y_lerp) * grads(b, y, x, d);
+ grads_image(b_in, top_y_index, left_x_index, d) +=
+ static_cast<T>((1 - x_lerp) * dtop);
+ grads_image(b_in, top_y_index, right_x_index, d) +=
+ static_cast<T>(x_lerp * dtop);
+ const float dbottom = y_lerp * grads(b, y, x, d);
+ grads_image(b_in, bottom_y_index, left_x_index, d) +=
+ static_cast<T>((1 - x_lerp) * dbottom);
+ grads_image(b_in, bottom_y_index, right_x_index, d) +=
+ static_cast<T>(x_lerp * dbottom);
+ }
+ } else { // method_name == "nearest"
+ for (int d = 0; d < depth; ++d) {
+ int closest_x_index = roundf(in_x);
+ int closest_y_index = roundf(in_y);
+ grads_image(b_in, closest_y_index, closest_x_index, d) +=
+ static_cast<T>(grads(b, y, x, d));
+ }
}
}
}