diff options
author | David G. Andersen <dga@google.com> | 2016-03-17 20:01:14 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-03-18 08:47:49 -0700 |
commit | cfa2ac07d3cf119734d1eee8d1d081543d8152b1 (patch) | |
tree | b209eac404dca549bb5f0f6340443fd343c56482 /tensorflow/core/kernels/resize_nearest_neighbor_op.cc | |
parent | 16064000077ef6bf2a7f93ce1a5730951d009af1 (diff) |
Refactoring common checking and size computation code into a
separate struct that is shared by all of the image resizers.
Normalizes the error checking across all of the resizers.
Also added a max size check to nearest_neighbor - because of
the floats, it starts to produce bad results after 2^24px
in either direction. Not that anyone does that, but it's good
to be precise about it.
Change: 117516271
Diffstat (limited to 'tensorflow/core/kernels/resize_nearest_neighbor_op.cc')
-rw-r--r-- | tensorflow/core/kernels/resize_nearest_neighbor_op.cc | 65 |
1 files changed, 19 insertions, 46 deletions
diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc index 059ef83bb0..26cdac1519 100644 --- a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc +++ b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/image_resizer_state.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" @@ -44,56 +45,28 @@ class ResizeNearestNeighborOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); - OP_REQUIRES(context, input.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", - input.shape().DebugString())); - const Tensor& shape_t = context->input(1); - OP_REQUIRES(context, shape_t.dims() == 1, - errors::InvalidArgument("shape_t must be 1-dimensional", - shape_t.shape().DebugString())); - OP_REQUIRES(context, shape_t.NumElements() == 2, - errors::InvalidArgument("shape_t must have two elements", - shape_t.shape().DebugString())); + ImageResizerState st(align_corners_); + st.ValidateAndCreateOutput(context, input); - auto sizes = shape_t.vec<int32>(); - OP_REQUIRES(context, sizes(0) > 0 && sizes(1) > 0, - errors::InvalidArgument("shape_t's elements must be positive")); - - // Initialize shape to the batch size of the input, then add - // the rest of the dimensions - Tensor* output = nullptr; - OP_REQUIRES_OK( - context, context->allocate_output(0, TensorShape({input.dim_size(0), sizes(0), - sizes(1), input.dim_size(3)}), - &output)); + if (!context->status().ok()) return; - const int64 batch_size = input.dim_size(0); - const int64 in_height = input.dim_size(1); - const int64 in_width = input.dim_size(2); - const int64 channels = input.dim_size(3); - const int64 out_height = output->dim_size(1); - const int64 out_width = output->dim_size(2); + OP_REQUIRES(context, st.in_height < (1 << 24) && st.in_width < (1 << 24), + errors::InvalidArgument("nearest neighbor requires max height " + "& width of 2^24")); typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>(); - typename TTypes<T, 4>::Tensor output_data = output->tensor<T, 4>(); - - const float height_scale = - (align_corners_ && out_height > 1) - ? (in_height - 1) / static_cast<float>(out_height - 1) - : in_height / static_cast<float>(out_height); - const float width_scale = - (align_corners_ && out_width > 1) - ? (in_width - 1) / static_cast<float>(out_width - 1) - : in_width / static_cast<float>(out_width); - - for (int b = 0; b < batch_size; ++b) { - for (int y = 0; y < out_height; ++y) { - const int in_y = std::min(static_cast<int64>(floorf(y * height_scale)), - (in_height - 1)); - for (int x = 0; x < out_width; ++x) { - const int in_x = std::min(static_cast<int64>(floorf(x * width_scale)), - (in_width - 1)); - for (int c = 0; c < channels; ++c) { + typename TTypes<T, 4>::Tensor output_data = st.output->tensor<T, 4>(); + + for (int b = 0; b < st.batch_size; ++b) { + for (int y = 0; y < st.out_height; ++y) { + const int in_y = + std::min(static_cast<int64>(floorf(y * st.height_scale)), + (st.in_height - 1)); + for (int x = 0; x < st.out_width; ++x) { + const int in_x = + std::min(static_cast<int64>(floorf(x * st.width_scale)), + (st.in_width - 1)); + for (int c = 0; c < st.channels; ++c) { output_data(b, y, x, c) = input_data(b, in_y, in_x, c); } } |