diff options
Diffstat (limited to 'tensorflow/core/kernels/crop_and_resize_op.cc')
-rw-r--r-- | tensorflow/core/kernels/crop_and_resize_op.cc | 151 |
1 files changed, 95 insertions, 56 deletions
diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc index 54ef9c6fb4..99d01b4db6 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.cc +++ b/tensorflow/core/kernels/crop_and_resize_op.cc @@ -110,10 +110,10 @@ class CropAndResizeOp : public AsyncOpKernel { public: explicit CropAndResizeOp(OpKernelConstruction* context) : AsyncOpKernel(context) { - string method; - OP_REQUIRES_OK(context, context->GetAttr("method", &method)); - OP_REQUIRES(context, method == "bilinear", - errors::InvalidArgument("method must be 'bilinear'", method)); + OP_REQUIRES_OK(context, context->GetAttr("method", &method_)); + OP_REQUIRES(context, method_ == "bilinear" || method_ == "nearest", + errors::InvalidArgument( + "method must be 'bilinear' or 'nearest'", method_)); OP_REQUIRES_OK(context, context->GetAttr("extrapolation_value", &extrapolation_value_)); } @@ -178,7 +178,7 @@ class CropAndResizeOp : public AsyncOpKernel { const Tensor& box_index = context->input(2); const bool status = functor::CropAndResize<Device, T>()( context, image.tensor<T, 4>(), boxes.tensor<float, 2>(), - box_index.tensor<int32, 1>(), extrapolation_value_, + box_index.tensor<int32, 1>(), method_, extrapolation_value_, output->tensor<float, 4>()); if (!status) { context->SetStatus( @@ -193,6 +193,7 @@ class CropAndResizeOp : public AsyncOpKernel { private: float extrapolation_value_; + string method_; }; // Partial specialization of CropAndResize functor for a CPUDevice. @@ -203,7 +204,7 @@ struct CropAndResize<CPUDevice, T> { typename TTypes<T, 4>::ConstTensor image, typename TTypes<float, 2>::ConstTensor boxes, typename TTypes<int32, 1>::ConstTensor box_index, - float extrapolation_value, + const string& method_name, float extrapolation_value, typename TTypes<float, 4>::Tensor crops) { const int batch_size = image.dimension(0); const int image_height = image.dimension(1); @@ -247,37 +248,57 @@ struct CropAndResize<CPUDevice, T> { } continue; } - const int top_y_index = floorf(in_y); - const int bottom_y_index = ceilf(in_y); - const float y_lerp = in_y - top_y_index; - - for (int x = 0; x < crop_width; ++x) { - const float in_x = (crop_width > 1) - ? x1 * (image_width - 1) + x * width_scale - : 0.5 * (x1 + x2) * (image_width - 1); - if (in_x < 0 || in_x > image_width - 1) { + if (method_name == "bilinear") { + const int top_y_index = floorf(in_y); + const int bottom_y_index = ceilf(in_y); + const float y_lerp = in_y - top_y_index; + + for (int x = 0; x < crop_width; ++x) { + const float in_x = (crop_width > 1) + ? x1 * (image_width - 1) + x * width_scale + : 0.5 * (x1 + x2) * (image_width - 1); + if (in_x < 0 || in_x > image_width - 1) { + for (int d = 0; d < depth; ++d) { + crops(b, y, x, d) = extrapolation_value; + } + continue; + } + const int left_x_index = floorf(in_x); + const int right_x_index = ceilf(in_x); + const float x_lerp = in_x - left_x_index; + for (int d = 0; d < depth; ++d) { - crops(b, y, x, d) = extrapolation_value; + const float top_left(static_cast<float>( + image(b_in, top_y_index, left_x_index, d))); + const float top_right(static_cast<float>( + image(b_in, top_y_index, right_x_index, d))); + const float bottom_left(static_cast<float>( + image(b_in, bottom_y_index, left_x_index, d))); + const float bottom_right(static_cast<float>( + image(b_in, bottom_y_index, right_x_index, d))); + const float top = top_left + (top_right - top_left) * x_lerp; + const float bottom = + bottom_left + (bottom_right - bottom_left) * x_lerp; + crops(b, y, x, d) = top + (bottom - top) * y_lerp; } - continue; } - const int left_x_index = floorf(in_x); - const int right_x_index = ceilf(in_x); - const float x_lerp = in_x - left_x_index; - - for (int d = 0; d < depth; ++d) { - const float top_left(static_cast<float>( - image(b_in, top_y_index, left_x_index, d))); - const float top_right(static_cast<float>( - image(b_in, top_y_index, right_x_index, d))); - const float bottom_left(static_cast<float>( - image(b_in, bottom_y_index, left_x_index, d))); - const float bottom_right(static_cast<float>( - image(b_in, bottom_y_index, right_x_index, d))); - const float top = top_left + (top_right - top_left) * x_lerp; - const float bottom = - bottom_left + (bottom_right - bottom_left) * x_lerp; - crops(b, y, x, d) = top + (bottom - top) * y_lerp; + } else { // method == "nearest" + for (int x = 0; x < crop_width; ++x) { + const float in_x = (crop_width > 1) + ? x1 * (image_width - 1) + x * width_scale + : 0.5 * (x1 + x2) * (image_width - 1); + if (in_x < 0 || in_x > image_width - 1) { + for (int d = 0; d < depth; ++d) { + crops(b, y, x, d) = extrapolation_value; + } + continue; + } + const int closest_x_index = roundf(in_x); + const int closest_y_index = roundf(in_y); + for (int d = 0; d < depth; ++d) { + crops(b, y, x, d) = static_cast<float>( + image(b_in, closest_y_index, closest_x_index, d)); + } } } } @@ -285,12 +306,17 @@ struct CropAndResize<CPUDevice, T> { }; // A rough estimation of the cost for each cropped box. - const double cost_per_pixel = + double cost_per_pixel = depth * (Eigen::TensorOpCost::AddCost<float>() * 6 + Eigen::TensorOpCost::MulCost<float>() * 3 + Eigen::TensorOpCost::CastCost<T, float>() * 4) + (Eigen::TensorOpCost::AddCost<float>() * 2 + Eigen::TensorOpCost::AddCost<float>() * 3); + if (method_name == "nearest") { + cost_per_pixel = depth * Eigen::TensorOpCost::CastCost<T, float>() + + Eigen::TensorOpCost::AddCost<float>() * 4 + + Eigen::TensorOpCost::MulCost<float>() * 4; + } const double cost_per_box = crop_height * crop_width * cost_per_pixel; const DeviceBase::CpuWorkerThreads& worker_threads = @@ -309,10 +335,10 @@ class CropAndResizeGradImageOp : public AsyncOpKernel { public: explicit CropAndResizeGradImageOp(OpKernelConstruction* context) : AsyncOpKernel(context) { - string method; - OP_REQUIRES_OK(context, context->GetAttr("method", &method)); - OP_REQUIRES(context, method == "bilinear", - errors::InvalidArgument("method must be 'bilinear'", method)); + OP_REQUIRES_OK(context, context->GetAttr("method", &method_)); + OP_REQUIRES(context, method_ == "bilinear" || method_ == "nearest", + errors::InvalidArgument( + "method must be 'bilinear' or 'nearest'", method_)); } void ComputeAsync(OpKernelContext* context, DoneCallback done) override { @@ -372,14 +398,14 @@ class CropAndResizeGradImageOp : public AsyncOpKernel { &output), done); - auto compute_callback = [context, output]() { + auto compute_callback = [this, context, output]() { const Tensor& grads = context->input(0); const Tensor& boxes = context->input(1); const Tensor& box_index = context->input(2); const bool status = functor::CropAndResizeBackpropImage<Device, T>()( context->eigen_device<Device>(), grads.tensor<float, 4>(), boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(), - output->tensor<T, 4>()); + output->tensor<T, 4>(), method_); if (!status) { context->SetStatus(errors::Internal( "Failed launch CropAndResizeBackpropImage kernel.")); @@ -390,6 +416,9 @@ class CropAndResizeGradImageOp : public AsyncOpKernel { batch_size, std::move(compute_callback), std::move(done)); } + + private: + string method_; }; // Partial specialization of CropAndResizeBackpropImage functor for a CPUDevice. @@ -400,7 +429,8 @@ struct CropAndResizeBackpropImage<CPUDevice, T> { typename TTypes<float, 4>::ConstTensor grads, typename TTypes<float, 2>::ConstTensor boxes, typename TTypes<int32, 1>::ConstTensor box_index, - typename TTypes<T, 4>::Tensor grads_image) { + typename TTypes<T, 4>::Tensor grads_image, + const string& method_name) { const int batch_size = grads_image.dimension(0); const int image_height = grads_image.dimension(1); const int image_width = grads_image.dimension(2); @@ -448,21 +478,30 @@ struct CropAndResizeBackpropImage<CPUDevice, T> { if (in_x < 0 || in_x > image_width - 1) { continue; } - const int left_x_index = floorf(in_x); - const int right_x_index = ceilf(in_x); - const float x_lerp = in_x - left_x_index; + if (method_name == "bilinear") { + const int left_x_index = floorf(in_x); + const int right_x_index = ceilf(in_x); + const float x_lerp = in_x - left_x_index; - for (int d = 0; d < depth; ++d) { - const float dtop = (1 - y_lerp) * grads(b, y, x, d); - grads_image(b_in, top_y_index, left_x_index, d) += - static_cast<T>((1 - x_lerp) * dtop); - grads_image(b_in, top_y_index, right_x_index, d) += - static_cast<T>(x_lerp * dtop); - const float dbottom = y_lerp * grads(b, y, x, d); - grads_image(b_in, bottom_y_index, left_x_index, d) += - static_cast<T>((1 - x_lerp) * dbottom); - grads_image(b_in, bottom_y_index, right_x_index, d) += - static_cast<T>(x_lerp * dbottom); + for (int d = 0; d < depth; ++d) { + const float dtop = (1 - y_lerp) * grads(b, y, x, d); + grads_image(b_in, top_y_index, left_x_index, d) += + static_cast<T>((1 - x_lerp) * dtop); + grads_image(b_in, top_y_index, right_x_index, d) += + static_cast<T>(x_lerp * dtop); + const float dbottom = y_lerp * grads(b, y, x, d); + grads_image(b_in, bottom_y_index, left_x_index, d) += + static_cast<T>((1 - x_lerp) * dbottom); + grads_image(b_in, bottom_y_index, right_x_index, d) += + static_cast<T>(x_lerp * dbottom); + } + } else { // method_name == "nearest" + for (int d = 0; d < depth; ++d) { + int closest_x_index = roundf(in_x); + int closest_y_index = roundf(in_y); + grads_image(b_in, closest_y_index, closest_x_index, d) += + static_cast<T>(grads(b, y, x, d)); + } } } } |