aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
diff options
context:
space:
mode:
authorGravatar David G. Andersen <dga@google.com>2016-03-17 20:01:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-03-18 08:47:49 -0700
commitcfa2ac07d3cf119734d1eee8d1d081543d8152b1 (patch)
treeb209eac404dca549bb5f0f6340443fd343c56482 /tensorflow/core/kernels/resize_nearest_neighbor_op.cc
parent16064000077ef6bf2a7f93ce1a5730951d009af1 (diff)
Refactoring common checking and size computation code into a
separate struct that is shared by all of the image resizers. Normalizes the error checking across all of the resizers. Also added a max size check to nearest_neighbor - because of the floats, it starts to produce bad results after 2^24px in either direction. Not that anyone does that, but it's good to be precise about it. Change: 117516271
Diffstat (limited to 'tensorflow/core/kernels/resize_nearest_neighbor_op.cc')
-rw-r--r--tensorflow/core/kernels/resize_nearest_neighbor_op.cc65
1 files changed, 19 insertions, 46 deletions
diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
index 059ef83bb0..26cdac1519 100644
--- a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
+++ b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
@@ -44,56 +45,28 @@ class ResizeNearestNeighborOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
- OP_REQUIRES(context, input.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input.shape().DebugString()));
- const Tensor& shape_t = context->input(1);
- OP_REQUIRES(context, shape_t.dims() == 1,
- errors::InvalidArgument("shape_t must be 1-dimensional",
- shape_t.shape().DebugString()));
- OP_REQUIRES(context, shape_t.NumElements() == 2,
- errors::InvalidArgument("shape_t must have two elements",
- shape_t.shape().DebugString()));
+ ImageResizerState st(align_corners_);
+ st.ValidateAndCreateOutput(context, input);
- auto sizes = shape_t.vec<int32>();
- OP_REQUIRES(context, sizes(0) > 0 && sizes(1) > 0,
- errors::InvalidArgument("shape_t's elements must be positive"));
-
- // Initialize shape to the batch size of the input, then add
- // the rest of the dimensions
- Tensor* output = nullptr;
- OP_REQUIRES_OK(
- context, context->allocate_output(0, TensorShape({input.dim_size(0), sizes(0),
- sizes(1), input.dim_size(3)}),
- &output));
+ if (!context->status().ok()) return;
- const int64 batch_size = input.dim_size(0);
- const int64 in_height = input.dim_size(1);
- const int64 in_width = input.dim_size(2);
- const int64 channels = input.dim_size(3);
- const int64 out_height = output->dim_size(1);
- const int64 out_width = output->dim_size(2);
+ OP_REQUIRES(context, st.in_height < (1 << 24) && st.in_width < (1 << 24),
+ errors::InvalidArgument("nearest neighbor requires max height "
+ "& width of 2^24"));
typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
- typename TTypes<T, 4>::Tensor output_data = output->tensor<T, 4>();
-
- const float height_scale =
- (align_corners_ && out_height > 1)
- ? (in_height - 1) / static_cast<float>(out_height - 1)
- : in_height / static_cast<float>(out_height);
- const float width_scale =
- (align_corners_ && out_width > 1)
- ? (in_width - 1) / static_cast<float>(out_width - 1)
- : in_width / static_cast<float>(out_width);
-
- for (int b = 0; b < batch_size; ++b) {
- for (int y = 0; y < out_height; ++y) {
- const int in_y = std::min(static_cast<int64>(floorf(y * height_scale)),
- (in_height - 1));
- for (int x = 0; x < out_width; ++x) {
- const int in_x = std::min(static_cast<int64>(floorf(x * width_scale)),
- (in_width - 1));
- for (int c = 0; c < channels; ++c) {
+ typename TTypes<T, 4>::Tensor output_data = st.output->tensor<T, 4>();
+
+ for (int b = 0; b < st.batch_size; ++b) {
+ for (int y = 0; y < st.out_height; ++y) {
+ const int in_y =
+ std::min(static_cast<int64>(floorf(y * st.height_scale)),
+ (st.in_height - 1));
+ for (int x = 0; x < st.out_width; ++x) {
+ const int in_x =
+ std::min(static_cast<int64>(floorf(x * st.width_scale)),
+ (st.in_width - 1));
+ for (int c = 0; c < st.channels; ++c) {
output_data(b, y, x, c) = input_data(b, in_y, in_x, c);
}
}