aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-06-29 15:32:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-29 16:48:22 -0700
commit1d92cfcbf5c157b3e4069741ae5bdbbea6666dc5 (patch)
tree414fef12a20b9be5e9b5826327bfb7c9424d5e5b /tensorflow/core
parente8974bac93f18e249676f4cd2e9bdbac2c813add (diff)
Correct bug in crop_and_resize which caused failures to some tests.
Change: 126246458
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op.cc104
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op.h6
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc153
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op_test.cc31
4 files changed, 191 insertions, 103 deletions
diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc
index 4e50a04190..caf73420ba 100644
--- a/tensorflow/core/kernels/crop_and_resize_op.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op.cc
@@ -42,6 +42,10 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
const Tensor& boxes,
const Tensor& box_ind,
int* num_boxes) {
+ if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) {
+ *num_boxes = 0;
+ return;
+ }
// The shape of 'boxes' is [num_boxes, 4].
OP_REQUIRES(context, boxes.dims() == 2,
errors::InvalidArgument("boxes must be 2-D",
@@ -132,9 +136,13 @@ class CropAndResizeOp : public OpKernel {
CheckValidBoxInd<Device>(context, box_ind_data, batch);
- functor::CropAndResize<Device, T>()(context->eigen_device<Device>(),
- image_data, boxes_data, box_ind_data,
- extrapolation_value_, crops_data);
+ bool status = functor::CropAndResize<Device, T>()(
+ context->eigen_device<Device>(), image_data, boxes_data, box_ind_data,
+ extrapolation_value_, crops_data);
+ if (!status) {
+ context->SetStatus(
+ errors::Internal("Failed launch CropAndResizeKernel."));
+ }
}
private:
@@ -145,11 +153,12 @@ class CropAndResizeOp : public OpKernel {
namespace functor {
template <typename T>
struct CropAndResize<CPUDevice, T> {
- void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
+ bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
float extrapolation_value,
typename TTypes<float, 4>::Tensor crops) {
+ const int batch = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@@ -163,7 +172,11 @@ struct CropAndResize<CPUDevice, T> {
const float x1 = boxes(b, 1);
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
+
const int32 b_in = box_ind(b);
+ if (b_in < 0 || b_in >= batch) {
+ continue;
+ }
const float height_scale =
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
@@ -217,6 +230,7 @@ struct CropAndResize<CPUDevice, T> {
}
}
}
+ return true;
}
};
} // namespace functor
@@ -235,6 +249,7 @@ class CropAndResizeGradImageOp : public OpKernel {
void Compute(OpKernelContext* context) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0);
+
OP_REQUIRES(context, grads.dims() == 4,
errors::InvalidArgument("grads image must be 4-D",
grads.shape().DebugString()));
@@ -294,9 +309,13 @@ class CropAndResizeGradImageOp : public OpKernel {
CheckValidBoxInd<Device>(context, box_ind_data, batch);
- functor::CropAndResizeBackpropImage<Device, T>()(
+ bool status = functor::CropAndResizeBackpropImage<Device, T>()(
context->eigen_device<Device>(), grads_data, boxes_data, box_ind_data,
output_data);
+ if (!status) {
+ context->SetStatus(
+ errors::Internal("Failed launch CropAndResizeBackpropImageKernel."));
+ }
}
};
@@ -304,11 +323,12 @@ class CropAndResizeGradImageOp : public OpKernel {
namespace functor {
template <typename T>
struct CropAndResizeBackpropImage<CPUDevice, T> {
- void operator()(const CPUDevice& d,
+ bool operator()(const CPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<T, 4>::Tensor grads_image) {
+ const int batch = grads_image.dimension(0);
const int image_height = grads_image.dimension(1);
const int image_width = grads_image.dimension(2);
@@ -324,7 +344,11 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
const float x1 = boxes(b, 1);
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
+
const int32 b_in = box_ind(b);
+ if (b_in < 0 || b_in >= batch) {
+ continue;
+ }
const float height_scale =
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
@@ -370,6 +394,7 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
}
}
}
+ return true;
}
};
} // namespace functor
@@ -388,6 +413,7 @@ class CropAndResizeGradBoxesOp : public OpKernel {
void Compute(OpKernelContext* context) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0);
+
OP_REQUIRES(context, grads.dims() == 4,
errors::InvalidArgument("grads image must be 4-D",
grads.shape().DebugString()));
@@ -441,9 +467,13 @@ class CropAndResizeGradBoxesOp : public OpKernel {
CheckValidBoxInd<Device>(context, box_ind_data, batch);
- functor::CropAndResizeBackpropBoxes<Device, T>()(
+ bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
context->eigen_device<Device>(), grads_data, image_data, boxes_data,
box_ind_data, output_data);
+ if (!status) {
+ context->SetStatus(
+ errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel."));
+ }
}
};
@@ -451,12 +481,13 @@ class CropAndResizeGradBoxesOp : public OpKernel {
namespace functor {
template <typename T>
struct CropAndResizeBackpropBoxes<CPUDevice, T> {
- void operator()(const CPUDevice& d,
+ bool operator()(const CPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<float, 2>::Tensor grads_boxes) {
+ const int batch = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@@ -472,7 +503,11 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
const float x1 = boxes(b, 1);
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
+
const int32 b_in = box_ind(b);
+ if (b_in < 0 || b_in >= batch) {
+ continue;
+ }
const float height_ratio =
(crop_height > 1)
@@ -547,6 +582,7 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
}
}
}
+ return true;
}
};
} // namespace functor
@@ -563,37 +599,25 @@ inline void CheckValidBoxInd<CPUDevice>(
}
}
-#define REGISTER_KERNEL(T) \
- REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .HostMemory("crop_size"), \
- CropAndResizeOp<CPUDevice, T>);
-
-TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
-
-#undef REGISTER_KERNEL
-
-#define REGISTER_KERNEL(T) \
- REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .HostMemory("image_size"), \
- CropAndResizeGradImageOp<CPUDevice, T>);
-
-TF_CALL_half(REGISTER_KERNEL);
-TF_CALL_float(REGISTER_KERNEL);
-TF_CALL_double(REGISTER_KERNEL);
-
-#undef REGISTER_KERNEL
-
-#define REGISTER_KERNEL(T) \
- REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T"), \
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("crop_size"), \
+ CropAndResizeOp<CPUDevice, T>); \
+ \
+ REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("image_size"), \
+ CropAndResizeGradImageOp<CPUDevice, T>); \
+ \
+ REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
CropAndResizeGradBoxesOp<CPUDevice, T>);
-TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
+TF_CALL_float(REGISTER_KERNEL);
#undef REGISTER_KERNEL
@@ -613,6 +637,10 @@ template <>
inline void CheckValidBoxInd<GPUDevice>(
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
int batch) {
+ const int num_boxes = box_ind.dimension(0);
+ if (num_boxes == 0) {
+ return;
+ }
Tensor isvalid_tensor;
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<bool>::value,
@@ -657,7 +685,7 @@ inline void CheckValidBoxInd<GPUDevice>(
.TypeConstraint<T>("T"), \
CropAndResizeGradBoxesOp<GPUDevice, T>);
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_KERNEL);
+TF_CALL_float(REGISTER_KERNEL);
#undef REGISTER_KERNEL
diff --git a/tensorflow/core/kernels/crop_and_resize_op.h b/tensorflow/core/kernels/crop_and_resize_op.h
index 9278893704..22df1bdd56 100644
--- a/tensorflow/core/kernels/crop_and_resize_op.h
+++ b/tensorflow/core/kernels/crop_and_resize_op.h
@@ -26,7 +26,7 @@ namespace functor {
template <typename Device, typename T>
struct CropAndResize {
// We assume that the tensor sizes are correct.
- void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor image,
+ bool operator()(const Device& d, typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
float extrapolation_value,
@@ -36,7 +36,7 @@ struct CropAndResize {
template <typename Device, typename T>
struct CropAndResizeBackpropImage {
// We assume that the tensor sizes are correct.
- void operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
+ bool operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<T, 4>::Tensor grads_image);
@@ -45,7 +45,7 @@ struct CropAndResizeBackpropImage {
template <typename Device, typename T>
struct CropAndResizeBackpropBoxes {
// We assume that the tensor sizes are correct.
- void operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
+ bool operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
diff --git a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
index 3759a8cb4c..75146b28e6 100644
--- a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
@@ -33,27 +33,30 @@ typedef Eigen::GpuDevice GPUDevice;
namespace {
template <typename T>
-__global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
- const float* boxes_ptr,
- const int32* box_ind_ptr, int num_boxes,
- int image_height, int image_width,
- int crop_height, int crop_width, int depth,
- float extrapolation_value,
- float* crops_ptr) {
+__global__ void CropAndResizeKernel(
+ const int32 nthreads, const T* image_ptr, const float* boxes_ptr,
+ const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
+ int image_width, int crop_height, int crop_width, int depth,
+ float extrapolation_value, float* crops_ptr) {
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
- const int d = out_idx % depth;
- const int out_idx2 = out_idx / depth;
- const int x = out_idx2 % crop_width;
- const int out_idx3 = out_idx2 / crop_width;
- const int y = out_idx3 % crop_height;
- const int b = out_idx3 / crop_height;
+ int idx = out_idx;
+ const int d = idx % depth;
+ idx /= depth;
+ const int x = idx % crop_width;
+ idx /= crop_width;
+ const int y = idx % crop_height;
+ const int b = idx / crop_height;
const float y1 = boxes_ptr[b * 4];
const float x1 = boxes_ptr[b * 4 + 1];
const float y2 = boxes_ptr[b * 4 + 2];
const float x2 = boxes_ptr[b * 4 + 3];
+
const int32 b_in = box_ind_ptr[b];
+ if (b_in < 0 || b_in >= batch) {
+ continue;
+ }
const float height_scale =
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
@@ -66,7 +69,7 @@ __global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
: 0.5 * (y1 + y2) * (image_height - 1);
if (in_y < 0 || in_y > image_height - 1) {
crops_ptr[out_idx] = extrapolation_value;
- return;
+ continue;
}
const float in_x = (crop_width > 1)
@@ -74,7 +77,7 @@ __global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
: 0.5 * (x1 + x2) * (image_width - 1);
if (in_x < 0 || in_x > image_width - 1) {
crops_ptr[out_idx] = extrapolation_value;
- return;
+ continue;
}
const int top_y_index = floorf(in_y);
@@ -114,22 +117,28 @@ __global__ void CropAndResizeKernel(const int32 nthreads, const T* image_ptr,
template <typename T>
__global__ void CropAndResizeBackpropImageKernel(
const int32 nthreads, const float* grads_ptr, const float* boxes_ptr,
- const int32* box_ind_ptr, int num_boxes, int image_height, int image_width,
- int crop_height, int crop_width, int depth, T* grads_image_ptr) {
+ const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
+ int image_width, int crop_height, int crop_width, int depth,
+ T* grads_image_ptr) {
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
- const int d = out_idx % depth;
- const int out_idx2 = out_idx / depth;
- const int x = out_idx2 % crop_width;
- const int out_idx3 = out_idx2 / crop_width;
- const int y = out_idx3 % crop_height;
- const int b = out_idx3 / crop_height;
+ int idx = out_idx;
+ const int d = idx % depth;
+ idx /= depth;
+ const int x = idx % crop_width;
+ idx /= crop_width;
+ const int y = idx % crop_height;
+ const int b = idx / crop_height;
const float y1 = boxes_ptr[b * 4];
const float x1 = boxes_ptr[b * 4 + 1];
const float y2 = boxes_ptr[b * 4 + 2];
const float x2 = boxes_ptr[b * 4 + 3];
+
const int32 b_in = box_ind_ptr[b];
+ if (b_in < 0 || b_in >= batch) {
+ continue;
+ }
const float height_scale =
(crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
@@ -141,14 +150,14 @@ __global__ void CropAndResizeBackpropImageKernel(
? y1 * (image_height - 1) + y * height_scale
: 0.5 * (y1 + y2) * (image_height - 1);
if (in_y < 0 || in_y > image_height - 1) {
- return;
+ continue;
}
const float in_x = (crop_width > 1)
? x1 * (image_width - 1) + x * width_scale
: 0.5 * (x1 + x2) * (image_width - 1);
if (in_x < 0 || in_x > image_width - 1) {
- return;
+ continue;
}
const int top_y_index = floorf(in_y);
@@ -192,23 +201,28 @@ __global__ void CropAndResizeBackpropImageKernel(
template <typename T>
__global__ void CropAndResizeBackpropBoxesKernel(
const int32 nthreads, const float* grads_ptr, const T* image_ptr,
- const float* boxes_ptr, const int32* box_ind_ptr, int num_boxes,
+ const float* boxes_ptr, const int32* box_ind_ptr, int num_boxes, int batch,
int image_height, int image_width, int crop_height, int crop_width,
int depth, float* grads_boxes_ptr) {
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
- const int d = out_idx % depth;
- const int out_idx2 = out_idx / depth;
- const int x = out_idx2 % crop_width;
- const int out_idx3 = out_idx2 / crop_width;
- const int y = out_idx3 % crop_height;
- const int b = out_idx3 / crop_height;
+ int idx = out_idx;
+ const int d = idx % depth;
+ idx /= depth;
+ const int x = idx % crop_width;
+ idx /= crop_width;
+ const int y = idx % crop_height;
+ const int b = idx / crop_height;
const float y1 = boxes_ptr[b * 4];
const float x1 = boxes_ptr[b * 4 + 1];
const float y2 = boxes_ptr[b * 4 + 2];
const float x2 = boxes_ptr[b * 4 + 3];
+
const int32 b_in = box_ind_ptr[b];
+ if (b_in < 0 || b_in >= batch) {
+ continue;
+ }
const float height_ratio =
(crop_height > 1)
@@ -226,14 +240,14 @@ __global__ void CropAndResizeBackpropBoxesKernel(
? y1 * (image_height - 1) + y * height_scale
: 0.5 * (y1 + y2) * (image_height - 1);
if (in_y < 0 || in_y > image_height - 1) {
- return;
+ continue;
}
const float in_x = (crop_width > 1)
? x1 * (image_width - 1) + x * width_scale
: 0.5 * (x1 + x2) * (image_width - 1);
if (in_x < 0 || in_x > image_width - 1) {
- return;
+ continue;
}
const int top_y_index = floorf(in_y);
@@ -306,11 +320,12 @@ namespace functor {
template <typename T>
struct CropAndResize<GPUDevice, T> {
- void operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
+ bool operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
float extrapolation_value,
typename TTypes<float, 4>::Tensor crops) {
+ const int batch = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@@ -320,19 +335,22 @@ struct CropAndResize<GPUDevice, T> {
const int depth = crops.dimension(3);
const int total_count = num_boxes * crop_height * crop_width * depth;
- CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
- CropAndResizeKernel<<<config.block_count, config.thread_per_block, 0,
- d.stream()>>>(
- config.virtual_thread_count, image.data(), boxes.data(), box_ind.data(),
- num_boxes, image_height, image_width, crop_height, crop_width, depth,
- extrapolation_value, crops.data());
+ if (total_count > 0) {
+ CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
+ CropAndResizeKernel<<<config.block_count, config.thread_per_block, 0,
+ d.stream()>>>(
+ config.virtual_thread_count, image.data(), boxes.data(),
+ box_ind.data(), num_boxes, batch, image_height, image_width,
+ crop_height, crop_width, depth, extrapolation_value, crops.data());
+ }
+ return d.ok();
}
};
template <typename T>
struct CropAndResizeBackpropImage<GPUDevice, T> {
- void operator()(const GPUDevice& d,
+ bool operator()(const GPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
@@ -351,29 +369,35 @@ struct CropAndResizeBackpropImage<GPUDevice, T> {
// Initialize grads_image with all zeros.
total_count = batch * image_height * image_width * depth;
- config = GetCudaLaunchConfig(total_count, d);
- SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
- total_count, grads_image.data());
+ if (total_count > 0) {
+ config = GetCudaLaunchConfig(total_count, d);
+ SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ config.virtual_thread_count, grads_image.data());
+ }
// Accumulate.
total_count = num_boxes * crop_height * crop_width * depth;
- config = GetCudaLaunchConfig(total_count, d);
- CropAndResizeBackpropImageKernel<<<
- config.block_count, config.thread_per_block, 0, d.stream()>>>(
- config.virtual_thread_count, grads.data(), boxes.data(), box_ind.data(),
- num_boxes, image_height, image_width, crop_height, crop_width, depth,
- grads_image.data());
+ if (total_count > 0) {
+ config = GetCudaLaunchConfig(total_count, d);
+ CropAndResizeBackpropImageKernel<<<
+ config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ config.virtual_thread_count, grads.data(), boxes.data(),
+ box_ind.data(), num_boxes, batch, image_height, image_width,
+ crop_height, crop_width, depth, grads_image.data());
+ }
+ return d.ok();
}
};
template <typename T>
struct CropAndResizeBackpropBoxes<GPUDevice, T> {
- void operator()(const GPUDevice& d,
+ bool operator()(const GPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
typename TTypes<float, 2>::Tensor grads_boxes) {
+ const int batch = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@@ -387,18 +411,23 @@ struct CropAndResizeBackpropBoxes<GPUDevice, T> {
// Initialize grads_boxes with all zeros.
total_count = num_boxes * 4;
- config = GetCudaLaunchConfig(total_count, d);
- SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
- total_count, grads_boxes.data());
+ if (total_count > 0) {
+ config = GetCudaLaunchConfig(total_count, d);
+ SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ config.virtual_thread_count, grads_boxes.data());
+ }
// Accumulate.
total_count = num_boxes * crop_height * crop_width * depth;
- config = GetCudaLaunchConfig(total_count, d);
- CropAndResizeBackpropBoxesKernel<<<
- config.block_count, config.thread_per_block, 0, d.stream()>>>(
- config.virtual_thread_count, grads.data(), image.data(), boxes.data(),
- box_ind.data(), num_boxes, image_height, image_width, crop_height,
- crop_width, depth, grads_boxes.data());
+ if (total_count > 0) {
+ config = GetCudaLaunchConfig(total_count, d);
+ CropAndResizeBackpropBoxesKernel<<<
+ config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ config.virtual_thread_count, grads.data(), image.data(), boxes.data(),
+ box_ind.data(), num_boxes, batch, image_height, image_width,
+ crop_height, crop_width, depth, grads_boxes.data());
+ }
+ return d.ok();
}
};
@@ -407,7 +436,7 @@ struct CropAndResizeBackpropBoxes<GPUDevice, T> {
template struct CropAndResizeBackpropImage<GPUDevice, T>; \
template struct CropAndResizeBackpropBoxes<GPUDevice, T>;
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_GPU_SPECS);
+TF_CALL_float(DEFINE_GPU_SPECS);
#undef DEFINE_GPU_SPECS
diff --git a/tensorflow/core/kernels/crop_and_resize_op_test.cc b/tensorflow/core/kernels/crop_and_resize_op_test.cc
index 38f3c1adb2..68e077e44d 100644
--- a/tensorflow/core/kernels/crop_and_resize_op_test.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op_test.cc
@@ -189,6 +189,24 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Extrapolated) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
+TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3NoCrop) {
+ MakeOp(0);
+ // Input:
+ // 1, 2
+ // 3, 4
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<float>(TensorShape({0, 4}), {});
+ AddInputFromArray<int32>(TensorShape({0}), {});
+ AddInputFromArray<int32>(TensorShape({2}), {3, 3});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({0, 3, 3, 1}));
+ // clang-format off
+ test::FillValues<float>(&expected, {});
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
TEST_F(CropAndResizeOpTest, TestInvalidInputShape) {
MakeOp(0);
AddInputFromArray<float>(TensorShape({2, 2, 1}), {1, 2, 3, 4});
@@ -201,6 +219,19 @@ TEST_F(CropAndResizeOpTest, TestInvalidInputShape) {
<< s;
}
+TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
+ MakeOp(0);
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
+ AddInputFromArray<int32>(TensorShape({2}), {0, 0});
+ AddInputFromArray<int32>(TensorShape({2}), {4, 4});
+ Status s = RunOpKernel();
+ ASSERT_FALSE(s.ok());
+ EXPECT_TRUE(
+ StringPiece(s.ToString()).contains("box_ind has incompatible shape"))
+ << s;
+}
+
TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
MakeOp(0);
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});