aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/crop_and_resize_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/crop_and_resize_op.cc')
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op.cc565
1 files changed, 321 insertions, 244 deletions
diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc
index 746fe63e2a..1c7afcf866 100644
--- a/tensorflow/core/kernels/crop_and_resize_op.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op.cc
@@ -19,6 +19,9 @@ limitations under the License.
#include "tensorflow/core/kernels/crop_and_resize_op.h"
+#include <functional>
+#include <string>
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -26,10 +29,13 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
@@ -37,41 +43,67 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+using Callback = std::function<void()>;
+
+namespace {
-static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
- const Tensor& boxes,
- const Tensor& box_ind,
- int* num_boxes) {
- if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) {
+static inline Status ParseAndCheckBoxSizes(const Tensor& boxes,
+ const Tensor& box_index,
+ int* num_boxes) {  // Out: boxes.dim_size(0), or 0 when both are empty.
+ if (boxes.NumElements() == 0 && box_index.NumElements() == 0) {
*num_boxes = 0;
- return;
+ return Status::OK();  // Both tensors are empty: accepted as zero boxes.
}
// The shape of 'boxes' is [num_boxes, 4].
- OP_REQUIRES(context, boxes.dims() == 2,
- errors::InvalidArgument("boxes must be 2-D",
- boxes.shape().DebugString()));
+ if (boxes.dims() != 2) {
+ return errors::InvalidArgument("boxes must be 2-D",
+ boxes.shape().DebugString());
+ }
*num_boxes = boxes.dim_size(0);
- OP_REQUIRES(context, boxes.dim_size(1) == 4,
- errors::InvalidArgument("boxes must have 4 columns"));
-
- // The shape of 'box_ind' is [num_boxes].
- OP_REQUIRES(context, box_ind.dims() == 1,
- errors::InvalidArgument("box_ind must be 1-D",
- box_ind.shape().DebugString()));
- OP_REQUIRES(context, box_ind.dim_size(0) == *num_boxes,
- errors::InvalidArgument("box_ind has incompatible shape"));
+ if (boxes.dim_size(1) != 4) {
+ return errors::InvalidArgument("boxes must have 4 columns");
+ }
+ // 'box_index' must be 1-D with one entry per box: shape [num_boxes].
+ if (box_index.dims() != 1) {
+ return errors::InvalidArgument("box_index must be 1-D",
+ box_index.shape().DebugString());
+ }
+ if (box_index.dim_size(0) != *num_boxes) {
+ return errors::InvalidArgument("box_index has incompatible shape");
+ }
+ return Status::OK();  // Every shape check above passed.
}
-// Verifies that all values in box_ind are in [0, batch).
+// Invokes the compute callback only if every value in box_index lies in
+// [0, batch_size); done is passed along so completion can be signalled either way.
template <typename Device>
-inline void CheckValidBoxInd(
- OpKernelContext* context,
- typename TTypes<int32, 1>::ConstTensor box_ind_data, int batch);
+inline void RunIfBoxIndexIsValid(
+ OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
+ int batch_size, Callback compute, Callback done);
+
+// Specialization of RunIfBoxIndexIsValid for a CPUDevice.
+template <>
+inline void RunIfBoxIndexIsValid<CPUDevice>(
+ OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
+ int batch_size, Callback compute, Callback done) {
+ const int num_boxes = box_index.dimension(0);
+ for (int b = 0; b < num_boxes; ++b) {  // Checks each index in turn, inline.
+ OP_REQUIRES_ASYNC(
+ context, FastBoundsCheck(box_index(b), batch_size),
+ errors::OutOfRange("box_index has values outside [0, batch_size)"),
+ done);
+ }
+ compute();  // Runs synchronously, after the validation loop.
+ done();  // Always follows compute on the success path.
+}
+
+} // namespace
template <typename Device, typename T>
-class CropAndResizeOp : public OpKernel {
+class CropAndResizeOp : public AsyncOpKernel {
public:
- explicit CropAndResizeOp(OpKernelConstruction* context) : OpKernel(context) {
+ explicit CropAndResizeOp(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {
string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear",
@@ -80,69 +112,77 @@ class CropAndResizeOp : public OpKernel {
&extrapolation_value_));
}
- void Compute(OpKernelContext* context) override {
- // The shape of 'image' is [batch, image_height, image_width, channels].
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+ // The shape of 'image' is [batch_size, image_height, image_width,
+ // channels].
const Tensor& image = context->input(0);
- OP_REQUIRES(context, image.dims() == 4,
- errors::InvalidArgument("input image must be 4-D",
- image.shape().DebugString()));
-
- const int batch = image.dim_size(0);
- const int image_height = image.dim_size(1);
- const int image_width = image.dim_size(2);
- const int depth = image.dim_size(3);
- OP_REQUIRES(context, image_height > 0 && image_width > 0,
- errors::InvalidArgument("image dimensions must be positive"));
-
// The shape of 'boxes' is [num_boxes, 4].
const Tensor& boxes = context->input(1);
-
- // The shape of 'box_ind' is [num_boxes].
- const Tensor& box_ind = context->input(2);
-
- int num_boxes = 0;
- ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
-
+ // The shape of 'box_index' is [num_boxes].
+ const Tensor& box_index = context->input(2);
// The shape of 'crop_size' is [2].
const Tensor& crop_size = context->input(3);
- OP_REQUIRES(context, crop_size.dims() == 1,
- errors::InvalidArgument("crop_size must be 1-D",
- crop_size.shape().DebugString()));
- OP_REQUIRES(context, crop_size.dim_size(0) == 2,
- errors::InvalidArgument("crop_size must have two elements",
- crop_size.shape().DebugString()));
-
+ // Validate inputs dimensions.
+ OP_REQUIRES_ASYNC(context, image.dims() == 4,
+ errors::InvalidArgument("input image must be 4-D",
+ image.shape().DebugString()),
+ done);
+ const int batch_size = image.dim_size(0);
+ const int image_height = image.dim_size(1);
+ const int image_width = image.dim_size(2);
+ const int depth = image.dim_size(3);
+ OP_REQUIRES_ASYNC(
+ context, image_height > 0 && image_width > 0,
+ errors::InvalidArgument("image dimensions must be positive"), done);
+ int num_boxes = 0;
+ OP_REQUIRES_OK_ASYNC(
+ context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
+
+ OP_REQUIRES_ASYNC(context, crop_size.dims() == 1,
+ errors::InvalidArgument("crop_size must be 1-D",
+ crop_size.shape().DebugString()),
+ done);
+ OP_REQUIRES_ASYNC(
+ context, crop_size.dim_size(0) == 2,
+ errors::InvalidArgument("crop_size must have two elements",
+ crop_size.shape().DebugString()),
+ done);
+
+ // Copy and validate crop sizes.
auto crop_size_vec = crop_size.vec<int32>();
const int crop_height = internal::SubtleMustCopy(crop_size_vec(0));
const int crop_width = internal::SubtleMustCopy(crop_size_vec(1));
- OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
- errors::InvalidArgument("crop dimensions must be positive"));
+ OP_REQUIRES_ASYNC(
+ context, crop_height > 0 && crop_width > 0,
+ errors::InvalidArgument("crop dimensions must be positive"), done);
// Allocate output tensor.
Tensor* output = nullptr;
- OP_REQUIRES_OK(
+ OP_REQUIRES_OK_ASYNC(
context,
context->allocate_output(
0, TensorShape({num_boxes, crop_height, crop_width, depth}),
- &output));
-
- typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>();
- typename TTypes<float, 2>::ConstTensor boxes_data =
- boxes.tensor<float, 2>();
- typename TTypes<int32, 1>::ConstTensor box_ind_data =
- box_ind.tensor<int32, 1>();
- typename TTypes<float, 4>::Tensor crops_data = output->tensor<float, 4>();
-
- CheckValidBoxInd<Device>(context, box_ind_data, batch);
-
- bool status = functor::CropAndResize<Device, T>()(
- context->eigen_device<Device>(), image_data, boxes_data, box_ind_data,
- extrapolation_value_, crops_data);
- if (!status) {
- context->SetStatus(
- errors::Internal("Failed launch CropAndResizeKernel."));
- }
+ &output),
+ done);
+
+ auto compute_callback = [this, context, output]() {
+ const Tensor& image = context->input(0);
+ const Tensor& boxes = context->input(1);
+ const Tensor& box_index = context->input(2);
+ const bool status = functor::CropAndResize<Device, T>()(
+ context->eigen_device<Device>(), image.tensor<T, 4>(),
+ boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
+ extrapolation_value_, output->tensor<float, 4>());
+ if (!status) {
+ context->SetStatus(
+ errors::Internal("Failed launch CropAndResizeKernel."));
+ }
+ };
+
+ RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
+ batch_size, std::move(compute_callback),
+ std::move(done));
}
private:
@@ -155,10 +195,10 @@ template <typename T>
struct CropAndResize<CPUDevice, T> {
bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
- typename TTypes<int32, 1>::ConstTensor box_ind,
+ typename TTypes<int32, 1>::ConstTensor box_index,
float extrapolation_value,
typename TTypes<float, 4>::Tensor crops) {
- const int batch = image.dimension(0);
+ const int batch_size = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@@ -173,8 +213,8 @@ struct CropAndResize<CPUDevice, T> {
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
- const int32 b_in = box_ind(b);
- if (b_in < 0 || b_in >= batch) {
+ const int32 b_in = box_index(b);
+ if (!FastBoundsCheck(b_in, batch_size)) {
continue;
}
@@ -235,89 +275,94 @@ struct CropAndResize<CPUDevice, T> {
return true;
}
};
+
} // namespace functor
+// Asynchronous kernel that fills a [batch_size, image_height, image_width,
+// depth] output of type T from the inputs 'grads', 'boxes', 'box_index' and
+// 'image_size'. Input shapes are validated first; the actual computation
+// (functor::CropAndResizeBackpropImage) is deferred to a callback that
+// RunIfBoxIndexIsValid invokes only when every box index is in range.
template <typename Device, typename T>
-class CropAndResizeGradImageOp : public OpKernel {
+class CropAndResizeGradImageOp : public AsyncOpKernel {
public:
explicit CropAndResizeGradImageOp(OpKernelConstruction* context)
- : OpKernel(context) {
+ : AsyncOpKernel(context) {
string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear",
errors::InvalidArgument("method must be 'bilinear'", method));
}
- void Compute(OpKernelContext* context) override {
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0);
-
- OP_REQUIRES(context, grads.dims() == 4,
- errors::InvalidArgument("grads image must be 4-D",
- grads.shape().DebugString()));
- const int crop_height = grads.dim_size(1);
- const int crop_width = grads.dim_size(2);
- OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
- errors::InvalidArgument("grads dimensions must be positive"));
-
// The shape of 'boxes' is [num_boxes, 4].
const Tensor& boxes = context->input(1);
-
- // The shape of 'box_ind' is [num_boxes].
- const Tensor& box_ind = context->input(2);
-
- int num_boxes = 0;
- ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
-
- OP_REQUIRES(
- context, grads.dim_size(0) == num_boxes,
- errors::InvalidArgument("boxes and grads have incompatible shape"));
-
+ // The shape of 'box_index' is [num_boxes].
+ const Tensor& box_index = context->input(2);
// The shape of 'image_size' is [4].
const Tensor& image_size = context->input(3);
- OP_REQUIRES(context, image_size.dims() == 1,
- errors::InvalidArgument("image_size must be 1-D",
- image_size.shape().DebugString()));
- OP_REQUIRES(context, image_size.dim_size(0) == 4,
- errors::InvalidArgument("image_size must have 4 elements",
- image_size.shape().DebugString()));
+ // Validate input shapes.
+ OP_REQUIRES_ASYNC(context, grads.dims() == 4,
+ errors::InvalidArgument("grads image must be 4-D",
+ grads.shape().DebugString()),
+ done);
+ const int crop_height = grads.dim_size(1);
+ const int crop_width = grads.dim_size(2);
+ OP_REQUIRES_ASYNC(
+ context, crop_height > 0 && crop_width > 0,
+ errors::InvalidArgument("grads dimensions must be positive"), done);
+ int num_boxes = 0;
+ OP_REQUIRES_OK_ASYNC(
+ context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
+ OP_REQUIRES_ASYNC(
+ context, grads.dim_size(0) == num_boxes,
+ errors::InvalidArgument("boxes and grads have incompatible shape"),
+ done);
+
+ OP_REQUIRES_ASYNC(context, image_size.dims() == 1,
+ errors::InvalidArgument("image_size must be 1-D",
+ image_size.shape().DebugString()),
+ done);
+ OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4,
+ errors::InvalidArgument("image_size must have 4 elements",
+ image_size.shape().DebugString()),
+ done);
auto image_size_vec = image_size.vec<int32>();
- const int batch = internal::SubtleMustCopy(image_size_vec(0));
+ const int batch_size = internal::SubtleMustCopy(image_size_vec(0));
const int image_height = internal::SubtleMustCopy(image_size_vec(1));
const int image_width = internal::SubtleMustCopy(image_size_vec(2));
const int depth = internal::SubtleMustCopy(image_size_vec(3));
-
- OP_REQUIRES(context, image_height > 0 && image_width > 0,
- errors::InvalidArgument("image dimensions must be positive"));
- OP_REQUIRES(
+ OP_REQUIRES_ASYNC(
+ context, image_height > 0 && image_width > 0,
+ errors::InvalidArgument("image dimensions must be positive"), done);
+ OP_REQUIRES_ASYNC(
context, grads.dim_size(3) == depth,
- errors::InvalidArgument("image_size and grads are incompatible"));
+ errors::InvalidArgument("image_size and grads are incompatible"), done);
+ // batch_size comes from the contents of 'image_size', so it can be
+ // negative. A negative dimension is expected to fail a CHECK when the
+ // output TensorShape below is built; report it as a regular error instead.
+ OP_REQUIRES_ASYNC(
+ context, batch_size >= 0,
+ errors::InvalidArgument("batch_size must be non-negative"), done);
// Allocate output tensor.
Tensor* output = nullptr;
- OP_REQUIRES_OK(
- context, context->allocate_output(
- 0, TensorShape({batch, image_height, image_width, depth}),
- &output));
-
- typename TTypes<float, 4>::ConstTensor grads_data =
- grads.tensor<float, 4>();
- typename TTypes<float, 2>::ConstTensor boxes_data =
- boxes.tensor<float, 2>();
- typename TTypes<int32, 1>::ConstTensor box_ind_data =
- box_ind.tensor<int32, 1>();
- typename TTypes<T, 4>::Tensor output_data = output->tensor<T, 4>();
-
- CheckValidBoxInd<Device>(context, box_ind_data, batch);
-
- bool status = functor::CropAndResizeBackpropImage<Device, T>()(
- context->eigen_device<Device>(), grads_data, boxes_data, box_ind_data,
- output_data);
- if (!status) {
- context->SetStatus(
- errors::Internal("Failed launch CropAndResizeBackpropImageKernel."));
- }
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ context->allocate_output(
+ 0, TensorShape({batch_size, image_height, image_width, depth}),
+ &output),
+ done);
+
+ // Deferred computation: the inputs are fetched again from the context when
+ // the callback runs rather than captured from this scope.
+ auto compute_callback = [context, output]() {
+ const Tensor& grads = context->input(0);
+ const Tensor& boxes = context->input(1);
+ const Tensor& box_index = context->input(2);
+ const bool status = functor::CropAndResizeBackpropImage<Device, T>()(
+ context->eigen_device<Device>(), grads.tensor<float, 4>(),
+ boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
+ output->tensor<T, 4>());
+ if (!status) {
+ context->SetStatus(errors::Internal(
+ "Failed launch CropAndResizeBackpropImage kernel."));
+ }
+ };
+
+ // Ownership of both callbacks is handed over; done is not used after this.
+ RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
+ batch_size, std::move(compute_callback),
+ std::move(done));
}
};
@@ -328,9 +373,9 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
bool operator()(const CPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
- typename TTypes<int32, 1>::ConstTensor box_ind,
+ typename TTypes<int32, 1>::ConstTensor box_index,
typename TTypes<T, 4>::Tensor grads_image) {
- const int batch = grads_image.dimension(0);
+ const int batch_size = grads_image.dimension(0);
const int image_height = grads_image.dimension(1);
const int image_width = grads_image.dimension(2);
@@ -347,8 +392,8 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
- const int32 b_in = box_ind(b);
- if (b_in < 0 || b_in >= batch) {
+ const int32 b_in = box_index(b);
+ if (!FastBoundsCheck(b_in, batch_size)) {
continue;
}
@@ -399,83 +444,90 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
return true;
}
};
+
} // namespace functor
template <typename Device, typename T>
-class CropAndResizeGradBoxesOp : public OpKernel {
+class CropAndResizeGradBoxesOp : public AsyncOpKernel {  // Output: float tensor of shape [num_boxes, 4].
public:
explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context)
- : OpKernel(context) {
+ : AsyncOpKernel(context) {
string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear",
errors::InvalidArgument("method must be 'bilinear'", method));
}
- void Compute(OpKernelContext* context) override {
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0);
+ // The shape of 'boxes' is [num_boxes, 4].
+ const Tensor& boxes = context->input(2);
+ // The shape of 'box_index' is [num_boxes].
+ const Tensor& box_index = context->input(3);
+ // The shape of 'image' is [batch_size, image_height, image_width, depth].
+ const Tensor& image = context->input(1);  // Input 1, although fetched after inputs 2 and 3.
- OP_REQUIRES(context, grads.dims() == 4,
- errors::InvalidArgument("grads image must be 4-D",
- grads.shape().DebugString()));
-
+ // Validate input shapes; done is passed to every check below.
+ OP_REQUIRES_ASYNC(context, grads.dims() == 4,
+ errors::InvalidArgument("grads image must be 4-D",
+ grads.shape().DebugString()),
+ done);
const int crop_height = grads.dim_size(1);
const int crop_width = grads.dim_size(2);
const int depth = grads.dim_size(3);
- OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
- errors::InvalidArgument("grads dimensions must be positive"));
-
- // The shape of 'image' is [batch, image_height, image_width, depth].
- const Tensor& image = context->input(1);
- OP_REQUIRES(context, image.dims() == 4,
- errors::InvalidArgument("input image must be 4-D",
- image.shape().DebugString()));
-
- const int batch = image.dim_size(0);
+ OP_REQUIRES_ASYNC(
+ context, crop_height > 0 && crop_width > 0,
+ errors::InvalidArgument("grads dimensions must be positive"), done);
+
+ OP_REQUIRES_ASYNC(context, image.dims() == 4,
+ errors::InvalidArgument("input image must be 4-D",
+ image.shape().DebugString()),
+ done);
+ const int batch_size = image.dim_size(0);
const int image_height = image.dim_size(1);
const int image_width = image.dim_size(2);
- OP_REQUIRES(context, image_height > 0 && image_width > 0,
- errors::InvalidArgument("image dimensions must be positive"));
- OP_REQUIRES(context, image.dim_size(3) == depth,
- errors::InvalidArgument("image, grads depth differ"));
-
- // The shape of 'boxes' is [num_boxes, 4].
- const Tensor& boxes = context->input(2);
-
- // The shape of 'box_ind' is [num_boxes].
- const Tensor& box_ind = context->input(3);
+ OP_REQUIRES_ASYNC(
+ context, image_height > 0 && image_width > 0,
+ errors::InvalidArgument("image dimensions must be positive"), done);
+ OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth,
+ errors::InvalidArgument("image, grads depth differ"),
+ done);
int num_boxes = 0;
- ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
+ OP_REQUIRES_OK_ASYNC(
+ context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
- OP_REQUIRES(
+ OP_REQUIRES_ASYNC(
context, grads.dim_size(0) == num_boxes,
- errors::InvalidArgument("boxes and grads have incompatible shape"));
+ errors::InvalidArgument("boxes and grads have incompatible shape"),
+ done);
// Allocate output tensor.
Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- 0, TensorShape({num_boxes, 4}), &output));
-
- typename TTypes<float, 4>::ConstTensor grads_data =
- grads.tensor<float, 4>();
- typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>();
- typename TTypes<float, 2>::ConstTensor boxes_data =
- boxes.tensor<float, 2>();
- typename TTypes<int32, 1>::ConstTensor box_ind_data =
- box_ind.tensor<int32, 1>();
- typename TTypes<float, 2>::Tensor output_data = output->tensor<float, 2>();
-
- CheckValidBoxInd<Device>(context, box_ind_data, batch);
-
- bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
- context->eigen_device<Device>(), grads_data, image_data, boxes_data,
- box_ind_data, output_data);
- if (!status) {
- context->SetStatus(
- errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel."));
- }
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ context->allocate_output(0, TensorShape({num_boxes, 4}), &output),
+ done);
+
+ auto compute_callback = [context, output]() {  // Re-fetches the inputs from context when invoked.
+ const Tensor& grads = context->input(0);
+ const Tensor& image = context->input(1);
+ const Tensor& boxes = context->input(2);
+ const Tensor& box_index = context->input(3);
+ const bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
+ context->eigen_device<Device>(), grads.tensor<float, 4>(),
+ image.tensor<T, 4>(), boxes.tensor<float, 2>(),
+ box_index.tensor<int32, 1>(), output->tensor<float, 2>());
+ if (!status) {
+ context->SetStatus(errors::Internal(
+ "Failed launch CropAndResizeBackpropBoxes kernel."));
+ }
+ };
+
+ RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
+ batch_size, std::move(compute_callback),
+ std::move(done));  // done is moved out here and not used afterwards.
}
};
@@ -487,9 +539,9 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
- typename TTypes<int32, 1>::ConstTensor box_ind,
+ typename TTypes<int32, 1>::ConstTensor box_index,
typename TTypes<float, 2>::Tensor grads_boxes) {
- const int batch = image.dimension(0);
+ const int batch_size = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@@ -506,8 +558,8 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
- const int32 b_in = box_ind(b);
- if (b_in < 0 || b_in >= batch) {
+ const int32 b_in = box_index(b);
+ if (!FastBoundsCheck(b_in, batch_size)) {
continue;
}
@@ -589,30 +641,19 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
return true;
}
};
-} // namespace functor
-// Specialization of CheckValidBoxInd for a CPUDevice.
-template <>
-inline void CheckValidBoxInd<CPUDevice>(
- OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
- int batch) {
- const int num_boxes = box_ind.dimension(0);
- for (int b = 0; b < num_boxes; ++b) {
- OP_REQUIRES(context, box_ind(b) >= 0 && box_ind(b) < batch,
- errors::OutOfRange("box_ind has values outside [0, batch)"));
- }
-}
+} // namespace functor
-#define REGISTER_KERNEL(T) \
- REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .HostMemory("crop_size"), \
- CropAndResizeOp<CPUDevice, T>); \
- \
- REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T"), \
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("crop_size"), \
+ CropAndResizeOp<CPUDevice, T>); \
+ \
+ REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
CropAndResizeGradBoxesOp<CPUDevice, T>);  // Registers the CropAndResize and CropAndResizeGradBoxes CPU kernels for type T.
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
@@ -634,50 +675,86 @@ TF_CALL_double(REGISTER_KERNEL);
#if GOOGLE_CUDA
-// Forward declaration of the CheckValidBoxIndHelper specialization for GPU.
+// Forward declaration of the CheckValidBoxIndexHelper specialization for GPU.
namespace functor {
template <>
-void CheckValidBoxIndHelper<GPUDevice>::operator()(
- const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_ind,
- int batch, typename TTypes<bool, 0>::Tensor isvalid);
-extern template struct CheckValidBoxIndHelper<GPUDevice>;
+void CheckValidBoxIndexHelper<GPUDevice>::operator()(
+ const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_index,
+ int batch_size, typename TTypes<bool, 0>::Tensor isvalid);
+extern template struct CheckValidBoxIndexHelper<GPUDevice>;  // Suppresses implicit instantiation in this file.
} // namespace functor
-// Specialization of CheckValidBoxInd for a GPUDevice.
+namespace {
+
+// Specialization of RunIfBoxIndexIsValid for a GPUDevice.
+// The bounds check runs on the device and its boolean result is copied into
+// host memory. compute and done are then invoked from a callback scheduled on
+// the stream through the event manager, so this function may return before
+// either of them has run. With zero boxes, compute and done run immediately.
template <>
-inline void CheckValidBoxInd<GPUDevice>(
- OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
- int batch) {
- const int num_boxes = box_ind.dimension(0);
+inline void RunIfBoxIndexIsValid<GPUDevice>(
+ OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
+ int batch_size, Callback compute, Callback done) {
+ const int num_boxes = box_index.dimension(0);
if (num_boxes == 0) {
+ compute();
+ done();
return;
}
- Tensor isvalid_tensor;
- OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<bool>::value,
- TensorShape({}), &isvalid_tensor));
- typename TTypes<bool, 0>::Tensor isvalid = isvalid_tensor.tensor<bool, 0>();
+ Tensor isvalid_dev_tensor;
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
+ &isvalid_dev_tensor),
+ done);
+ typename TTypes<bool, 0>::Tensor isvalid_dev =
+ isvalid_dev_tensor.tensor<bool, 0>();
- functor::CheckValidBoxIndHelper<GPUDevice>()(
- context->eigen_device<GPUDevice>(), box_ind, batch, isvalid);
+ // Run the actual box check on the device.
+ functor::CheckValidBoxIndexHelper<GPUDevice>()(
+ context->eigen_device<GPUDevice>(), box_index, batch_size, isvalid_dev);
+ // Copy the result back to the host.
auto* stream = context->op_device_context()->stream();
- OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
-
- bool isvalid_host = false;
- perftools::gputools::DeviceMemoryBase isvalid_gpu(isvalid.data(),
- sizeof(bool));
- stream->ThenMemcpy(&isvalid_host, isvalid_gpu, sizeof(bool));
- stream->BlockHostUntilDone();
-
- OP_REQUIRES(context, stream->ok(),
- errors::Internal("cudaMemcpy from device to host failed"));
-
- OP_REQUIRES(context, isvalid_host,
- errors::OutOfRange("box_ind has values outside [0, batch)"));
+ OP_REQUIRES_ASYNC(context, stream,
+ errors::Internal("No GPU stream available."), done);
+ Tensor isvalid_host_tensor;
+ // Use pinned host memory on the host to avoid unnecessary
+ // synchronization.
+ AllocatorAttributes alloc_attr;
+ alloc_attr.set_on_host(true);
+ alloc_attr.set_gpu_compatible(true);
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
+ &isvalid_host_tensor, alloc_attr),
+ done);
+ typename TTypes<bool, 0>::Tensor isvalid_host =
+ isvalid_host_tensor.tensor<bool, 0>();
+
+ perftools::gputools::DeviceMemoryBase wrapped(isvalid_dev.data(),
+ sizeof(bool));
+ const bool status = stream
+ ->ThenMemcpy(isvalid_host.data() /* destination */,
+ wrapped /* source */, sizeof(bool))
+ .ok();
+ OP_REQUIRES_ASYNC(
+ context, status,
+ errors::Internal("Failed to launch copy of isvalid from device to host."),
+ done);
+
+ // Capture the host tensor itself, not the non-owning isvalid_host view:
+ // this function returns before the callback runs, and the Tensor copy held
+ // by the closure keeps the destination buffer alive until it has been read.
+ auto wrapped_callback = [context, isvalid_host_tensor, compute, done]() {
+ OP_REQUIRES_ASYNC(
+ context, isvalid_host_tensor.scalar<bool>()(),
+ errors::OutOfRange("box_index has values outside [0, batch_size)"),
+ done);
+ compute();
+ done();
+ };
+
+ // Schedule the callback to run once the work queued on the stream so far,
+ // including the copy above, has completed.
+ context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
+ stream, wrapped_callback);
}
+} // namespace
+
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
.Device(DEVICE_GPU) \