aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/random_crop_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/random_crop_op.cc')
-rw-r--r--tensorflow/core/kernels/random_crop_op.cc103
1 files changed, 103 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/random_crop_op.cc b/tensorflow/core/kernels/random_crop_op.cc
new file mode 100644
index 0000000000..4fc12e92cb
--- /dev/null
+++ b/tensorflow/core/kernels/random_crop_op.cc
@@ -0,0 +1,103 @@
+// See docs in ../ops/image_ops.cc.
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+
+namespace tensorflow {
+
+template <typename T>
+class RandomCropOp : public OpKernel {
+ public:
+ explicit RandomCropOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, generator_.Init(context));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, input.dims() == 3,
+ errors::InvalidArgument("input must be 3-dimensional",
+ input.shape().ShortDebugString()));
+ const Tensor& shape_t = context->input(1);
+ OP_REQUIRES(context, shape_t.dims() == 1,
+ errors::InvalidArgument("shape_t must be 1-dimensional",
+ shape_t.shape().ShortDebugString()));
+ OP_REQUIRES(context, shape_t.NumElements() == 2,
+ errors::InvalidArgument("shape_t must have two elements",
+ shape_t.shape().ShortDebugString()));
+
+ auto shape_vec = shape_t.vec<int64>();
+ const int32 target_height = shape_vec(0);
+ const int32 target_width = shape_vec(1);
+
+ const int32 height = input.dim_size(0);
+ const int32 width = input.dim_size(1);
+ const int32 channels = input.dim_size(2);
+
+ // Initialize shape to the batch size of the input, then add
+ // the rest of the dimensions
+ Tensor* output = nullptr;
+ const auto output_shape =
+ TensorShape({target_height, target_width, channels});
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+
+ // If the target size matches the actual size, then do nothing.
+ if ((target_height == height) && (target_width == width)) {
+ *output = context->input(0);
+ }
+
+ // TODO(shlens): Implement edge case to guarantee output size dimensions.
+ // Edge case. The target dimensions are larger then the image, so
+ // zero-pad the image. This guarantees that the image will *always*
+ // be [target_height, target_width] in size.
+ OP_REQUIRES(context, width >= target_width, errors::FailedPrecondition(
+ "width must be >= target_width: width = ", width,
+ ", target_width = ", target_width));
+ OP_REQUIRES(context, height >= target_height, errors::FailedPrecondition(
+ "height must be >= target_height: height = ", height,
+ ", target_height = ", target_height));
+
+ int32 offset_height = 0;
+ int32 offset_width = 0;
+
+ auto local_gen = generator_.ReserveSamples32(2);
+ random::SimplePhilox random(&local_gen);
+
+ if (width > target_width) {
+ offset_width = random.Rand32() % (width - target_width + 1);
+ }
+ if (height > target_height) {
+ offset_height = random.Rand32() % (height - target_height + 1);
+ }
+
+ // TODO(shlens): Do this more efficiently with memcpy once padding is
+ // available for smaller images.
+ typename TTypes<T, 3>::ConstTensor input_data = input.tensor<T, 3>();
+ typename TTypes<T, 3>::Tensor output_data = output->tensor<T, 3>();
+
+ for (int y = 0; y < target_height; ++y) {
+ for (int x = 0; x < target_width; ++x) {
+ for (int c = 0; c < channels; ++c) {
+ output_data(y, x, c) =
+ input_data(y + offset_height, x + offset_width, c);
+ }
+ }
+ }
+ }
+
+ private:
+ GuardedPhiloxRandom generator_;
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RandomCrop").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ RandomCropOp<type>)
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+} // namespace tensorflow