aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/resampler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-26 05:47:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-26 05:51:07 -0700
commita80c8b583fa3d619b358a91aefd05069227b8967 (patch)
treef7270dedb338cb1ef5306a63d963416a52483aa1 /tensorflow/contrib/resampler
parent7a06e0af350e3e61bbf0ebee66d79ab15808c575 (diff)
Move resampler from sonnet to contrib.
PiperOrigin-RevId: 160134565
Diffstat (limited to 'tensorflow/contrib/resampler')
-rw-r--r--tensorflow/contrib/resampler/BUILD92
-rw-r--r--tensorflow/contrib/resampler/__init__.py26
-rw-r--r--tensorflow/contrib/resampler/kernels/resampler_ops.cc465
-rw-r--r--tensorflow/contrib/resampler/kernels/resampler_ops.h68
-rw-r--r--tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc310
-rw-r--r--tensorflow/contrib/resampler/ops/resampler_ops.cc59
-rw-r--r--tensorflow/contrib/resampler/python/__init__.py19
-rw-r--r--tensorflow/contrib/resampler/python/ops/resampler_ops.py69
-rw-r--r--tensorflow/contrib/resampler/python/ops/resampler_ops_test.py270
9 files changed, 1378 insertions, 0 deletions
diff --git a/tensorflow/contrib/resampler/BUILD b/tensorflow/contrib/resampler/BUILD
new file mode 100644
index 0000000000..1b9efd1ecd
--- /dev/null
+++ b/tensorflow/contrib/resampler/BUILD
@@ -0,0 +1,92 @@
+licenses(["notice"]) # Apache 2.0 License
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//visibility:public"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_custom_op_library",
+ "tf_custom_op_py_library",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+)
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+tf_custom_op_py_library(
+ name = "resampler_py",
+ srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+ dso = [":python/ops/_resampler_ops.so"],
+ kernels = [
+ ":resampler_ops_kernels",
+ ":resampler_ops_op_lib",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":resampler_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ ],
+)
+
+tf_kernel_library(
+ name = "resampler_ops_kernels",
+ prefix = "resampler_ops",
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+ alwayslink = 1,
+)
+
+tf_custom_op_library(
+ name = "python/ops/_resampler_ops.so",
+ srcs = [
+ "kernels/resampler_ops.cc",
+ "kernels/resampler_ops.h",
+ "ops/resampler_ops.cc",
+ ],
+ gpu_srcs = [
+ "kernels/resampler_ops_gpu.cu.cc",
+ "kernels/resampler_ops.h",
+ ],
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
+ "resampler_ops",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "resampler_ops",
+ deps = [":resampler_ops_op_lib"],
+)
+
+cuda_py_test(
+ name = "resampler_ops_test",
+ size = "small",
+ srcs = ["python/ops/resampler_ops_test.py"],
+ additional_deps = [
+ ":resampler_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:array_ops",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/resampler/__init__.py b/tensorflow/contrib/resampler/__init__.py
new file mode 100644
index 0000000000..3e04e5762d
--- /dev/null
+++ b/tensorflow/contrib/resampler/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops and modules related to resampler."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+# pylint: disable=wildcard-import
+from tensorflow.contrib.resampler.python.ops.resampler_ops import *
+from tensorflow.python.util.all_util import remove_undocumented
+
+remove_undocumented(__name__, ["resampler"])
diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops.cc b/tensorflow/contrib/resampler/kernels/resampler_ops.cc
new file mode 100644
index 0000000000..afc8bcd446
--- /dev/null
+++ b/tensorflow/contrib/resampler/kernels/resampler_ops.cc
@@ -0,0 +1,465 @@
+// Copyright 2017 The Sonnet Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/contrib/resampler/kernels/resampler_ops.h"
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+namespace functor {
+
+template <typename T>
+struct Resampler2DFunctor<CPUDevice, T>{
+ void operator ()(::tensorflow::OpKernelContext* ctx,
+ const CPUDevice& d,
+ const T* __restrict__ data,
+ const T* __restrict__ warp,
+ T* __restrict__ output,
+ const int batch_size,
+ const int data_height,
+ const int data_width,
+ const int data_channels,
+ const int num_sampling_points){
+ const int warp_batch_stride = num_sampling_points * 2;
+ const int data_batch_stride = data_height * data_width * data_channels;
+ const int output_batch_stride = num_sampling_points * data_channels;
+ const T zero = static_cast<T>(0.0);
+ const T one = static_cast<T>(1.0);
+
+ auto resample_batches = [&](const int start, const int limit) {
+ for (int batch_id = start; batch_id < limit; ++batch_id) {
+ // Utility lambda to access data point and set output values.
+ // The functions take care of performing the relevant pointer
+ // arithmetics abstracting away the low level details in the
+ // main loop over samples. Note that data is stored in NHWC format.
+ auto set_output = [&](const int sample_id,
+ const int channel,
+ const T value) {
+ output[batch_id * output_batch_stride +
+ sample_id * data_channels +
+ channel] = value;
+ };
+
+ auto get_data_point = [&](const int x,
+ const int y,
+ const int chan) {
+ const bool point_is_in_range =
+ (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1);
+ return point_is_in_range
+ ? data[batch_id * data_batch_stride +
+ data_channels * (y * data_width + x) +
+ chan]
+ : zero;
+ };
+
+ for (int sample_id = 0; sample_id < num_sampling_points; ++sample_id) {
+ const T x = warp[batch_id * warp_batch_stride + sample_id * 2];
+ const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1];
+ // The interpolation function:
+ // a) implicitly pads the input data with 0s (hence the unusual checks
+ // with {x,y} > -1)
+ // b) returns 0 when sampling outside the (padded) image.
+ // The effect is that the sampled signal smoothly goes to 0 outside
+ // the original input domain, rather than presenting a jump
+ // discontinuity at the image boundaries.
+ if (x > static_cast<T>(-1.0) &&
+ y > static_cast<T>(-1.0) &&
+ x < static_cast<T>(data_width) &&
+ y < static_cast<T>(data_height)) {
+ // Precompute floor (f) and ceil (c) values for x and y.
+ const int fx = std::floor(static_cast<float>(x));
+ const int fy = std::floor(static_cast<float>(y));
+ const int cx = fx + 1;
+ const int cy = fy + 1;
+ const T dx = static_cast<T>(cx) - x;
+ const T dy = static_cast<T>(cy) - y;
+
+ for (int chan = 0; chan < data_channels; ++chan) {
+ const T img_fxfy = dx * dy * get_data_point(fx, fy, chan);
+ const T img_cxcy = (one - dx) * (one - dy) *
+ get_data_point(cx, cy, chan);
+ const T img_fxcy = dx * (one - dy) *
+ get_data_point(fx, cy, chan);
+ const T img_cxfy = (one - dx) * dy *
+ get_data_point(cx, fy, chan);
+ set_output(sample_id, chan,
+ img_fxfy + img_cxcy + img_fxcy + img_cxfy);
+ }
+ } else {
+ for (int chan = 0; chan < data_channels; ++chan) {
+ set_output(sample_id, chan, zero);
+ }
+ }
+ }
+ }
+ };
+ // Rough estimate of work for each batch entry.
+ // From third_party/tensorflow/core/util/work_sharder.cc we gather that an
+ // estimate of the cost of each work unit is needed to correclty shard the
+ // workload. Shard assumes each cost unit is 1ns, minimum cost per shard
+ // being 10us.
+ const int64 cost = static_cast<int64>(num_sampling_points) *
+ data_channels * 1000;
+ auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
+ ::tensorflow::Shard(worker_threads.num_threads, worker_threads.workers,
+ batch_size, cost, resample_batches);
+ }
+};
+
+} // namespace functor
+
+template <typename Device, typename T>
+class ResamplerOp : public ::tensorflow::OpKernel {
+ public:
+ explicit ResamplerOp(::tensorflow::OpKernelConstruction* context) :
+ ::tensorflow::OpKernel(context) {}
+
+ void Compute(::tensorflow::OpKernelContext* ctx) override {
+ const ::tensorflow::Tensor& data = ctx->input(0);
+ const ::tensorflow::Tensor& warp = ctx->input(1);
+
+ const ::tensorflow::TensorShape& data_shape = data.shape();
+ OP_REQUIRES(ctx, data_shape.dims() == 4,
+ ::tensorflow::errors::Unimplemented(
+ "Only bilinear interpolation is currently supported. The "
+ "input data shape must be [batch_size, data_height, "
+ "data_width, data_channels], but is: ",
+ data_shape.DebugString()));
+ const ::tensorflow::TensorShape& warp_shape = warp.shape();
+ OP_REQUIRES(ctx,
+ ::tensorflow::TensorShapeUtils::IsMatrixOrHigher(warp_shape),
+ ::tensorflow::errors::InvalidArgument(
+ "warp should be at least a matrix, got shape ",
+ warp_shape.DebugString()));
+ OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims()-1) == 2,
+ ::tensorflow::errors::Unimplemented(
+ "Only bilinear interpolation is supported, warping "
+ "coordinates must be 2D; warp shape last entry should be "
+ "2, but shape vector is: ", warp_shape.DebugString()));
+ OP_REQUIRES(ctx, data_shape.dim_size(0) == warp_shape.dim_size(0),
+ ::tensorflow::errors::InvalidArgument(
+ "Batch size of data and warp tensor must be the same, but "
+ "input shapes are: ", data_shape.DebugString(), ", ",
+ warp_shape.DebugString()));
+ const int batch_size = data_shape.dim_size(0);
+ const int data_height = data_shape.dim_size(1);
+ const int data_width = data_shape.dim_size(2);
+ const int data_channels = data_shape.dim_size(3);
+ ::tensorflow::TensorShape output_shape = warp.shape();
+ output_shape.set_dim(output_shape.dims() - 1, data_channels);
+ const int num_sampling_points = warp.NumElements() / batch_size / 2;
+ ::tensorflow::Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output));
+
+ // Execute kernel only for nonempty output; otherwise Eigen crashes on GPU.
+ if (num_sampling_points > 0) {
+ functor::Resampler2DFunctor<Device, T>()(ctx,
+ ctx->eigen_device<Device>(),
+ data.flat<T>().data(),
+ warp.flat<T>().data(),
+ output->flat<T>().data(),
+ batch_size,
+ data_height,
+ data_width,
+ data_channels,
+ num_sampling_points);
+ }
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ResamplerOp);
+};
+
+
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Resampler") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("T"), \
+ ResamplerOp<CPUDevice, TYPE>);
+
+TF_CALL_half(REGISTER);
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+#undef REGISTER
+
+#if GOOGLE_CUDA
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("Resampler") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<TYPE>("T"), \
+ ResamplerOp<GPUDevice, TYPE>)
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+#undef REGISTER
+#endif // GOOGLE_CUDA
+
+
+namespace functor {
+
+template <typename T>
+struct ResamplerGrad2DFunctor<CPUDevice, T>{
+ void operator ()(::tensorflow::OpKernelContext* ctx,
+ const CPUDevice& d,
+ const T* __restrict__ data,
+ const T* __restrict__ warp,
+ const T* __restrict__ grad_output,
+ T* __restrict__ grad_data,
+ T* __restrict__ grad_warp,
+ const int batch_size,
+ const int data_height,
+ const int data_width,
+ const int data_channels,
+ const int num_sampling_points){
+ // Set gradients to 0, because the kernel incrementally updates the
+ // tensor entries by adding partial contributions.
+ const int resampler_output_size = batch_size * num_sampling_points *
+ data_channels;
+ const int grad_warp_size = resampler_output_size / data_channels * 2;
+ const int grad_data_size = data_height * data_width * data_channels *
+ batch_size;
+ memset(grad_data, 0, sizeof(T) * grad_data_size);
+ memset(grad_warp, 0, sizeof(T) * grad_warp_size);
+
+ const auto&& data_batch_stride = data_height * data_width * data_channels;
+ const auto&& warp_batch_stride = num_sampling_points * 2;
+ const int output_batch_stride = num_sampling_points * data_channels;
+ const T zero = static_cast<T>(0.0);
+ const T one = static_cast<T>(1.0);
+
+ auto update_grads_for_batches = [&](const int start, const int limit) {
+ for (int batch_id = start; batch_id < limit; ++batch_id) {
+ // Utility lambdas to access data and update gradient tensors.
+ // The functions take care of performing the relevant pointer
+ // arithmetics abstracting away the low level details in the
+ // main loop over samples. Note that data is stored in NHWC format.
+ auto get_data_point = [&](const int x,
+ const int y,
+ const int chan) {
+ const bool point_is_in_range =
+ (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1);
+ return point_is_in_range
+ ? data[batch_id * data_batch_stride +
+ data_channels * (y * data_width + x) +
+ chan]
+ : zero;
+ };
+
+ auto update_grad_data = [&](const int x, const int y, const int chan,
+ const T value) {
+ const bool point_is_in_range =
+ (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1);
+ if (point_is_in_range){
+ grad_data[batch_id * data_batch_stride +
+ data_channels * (y * data_width + x) +
+ chan] += value;
+ }
+ };
+
+ auto update_grad_warp = [&](const int sample_id,
+ const int channel,
+ const T value) {
+ grad_warp[batch_id * warp_batch_stride +
+ sample_id * 2 +
+ channel] += value;
+ };
+
+ for (int sample_id = 0; sample_id < num_sampling_points; ++sample_id) {
+ const T x = warp[batch_id * warp_batch_stride + sample_id * 2];
+ const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1];
+ // The interpolation function whose gradient this function implements:
+ // a) implicitly pads the input data with 0s (hence the unusual checks
+ // with {x,y} > -1)
+ // b) returns 0 when sampling outside the (padded) image.
+ // The effect is that the sampled signal smoothly goes to 0 outside
+ // the original input domain, rather than presenting a jump
+ // discontinuity at the image boundaries.
+ if (x > static_cast<T>(-1.0) &&
+ y > static_cast<T>(-1.0) &&
+ x < static_cast<T>(data_width) &&
+ y < static_cast<T>(data_height)) {
+ // Precompute floor (f) and ceil (c) values for x and y.
+ const int fx = std::floor(static_cast<float>(x));
+ const int fy = std::floor(static_cast<float>(y));
+ const int cx = fx + 1;
+ const int cy = fy + 1;
+ const T dx = static_cast<T>(cx) - x;
+ const T dy = static_cast<T>(cy) - y;
+
+ for (int chan = 0; chan < data_channels; ++chan) {
+ const T grad_output_value =
+ grad_output[batch_id * output_batch_stride +
+ sample_id * data_channels +
+ chan];
+ const T img_fxfy = get_data_point(fx, fy, chan);
+ const T img_cxcy = get_data_point(cx, cy, chan);
+ const T img_fxcy = get_data_point(fx, cy, chan);
+ const T img_cxfy = get_data_point(cx, fy, chan);
+
+ // Update partial gradients wrt relevant warp field entries
+ update_grad_warp(sample_id, 0,
+ grad_output_value *
+ ((one - dy) * (img_cxcy - img_fxcy) +
+ dy * (img_cxfy - img_fxfy)));
+
+ update_grad_warp(sample_id, 1,
+ grad_output_value *
+ ((one - dx) * (img_cxcy - img_cxfy) +
+ dx * (img_fxcy - img_fxfy)));
+
+ // Update partial gradients wrt sampled data
+ update_grad_data(fx, fy, chan,
+ grad_output_value * dx * dy);
+ update_grad_data(cx, cy, chan,
+ grad_output_value * (one - dx) * (one - dy));
+ update_grad_data(fx, cy, chan,
+ grad_output_value * dx * (one - dy));
+ update_grad_data(cx, fy, chan,
+ grad_output_value * (one - dx) * dy);
+ }
+ }
+ }
+ }
+ };
+ // Rough estimate of work for each batch entry.
+ // From third_party/tensorflow/core/util/work_sharder.cc we gather that an
+ // estimate of the cost of each work unit is needed to correctly shard the
+ // workload. Shard assumes each cost unit is 1ns, minimum cost per shard
+ // being 10us.
+ // TODO(fviola): Check out if there is a better way of doing this.
+ auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
+ const int64 cost = static_cast<int64>(num_sampling_points) *
+ data_channels * 1000;
+ ::tensorflow::Shard(worker_threads.num_threads, worker_threads.workers,
+ batch_size, cost, update_grads_for_batches);
+ }
+};
+
+} // namespace functor
+
+
+template <typename Device, typename T>
+class ResamplerGradOp : public ::tensorflow::OpKernel {
+ public:
+ explicit ResamplerGradOp(::tensorflow::OpKernelConstruction* context) :
+ ::tensorflow::OpKernel(context) {}
+
+ void Compute(::tensorflow::OpKernelContext* ctx) override {
+ const ::tensorflow::Tensor& data = ctx->input(0);
+ const ::tensorflow::Tensor& warp = ctx->input(1);
+ const ::tensorflow::Tensor& grad_output = ctx->input(2);
+
+ const ::tensorflow::TensorShape& data_shape = data.shape();
+ OP_REQUIRES(ctx, data_shape.dims() == 4,
+ ::tensorflow::errors::Unimplemented(
+ "Only bilinear interpolation is supported, the input data "
+ "tensor must be a batch of 2d data; data shape should have "
+ "4 entries corresponding to [batch_size, data_height, "
+ "data_width, data_channels], but is: ",
+ data_shape.DebugString()));
+ const int batch_size = data_shape.dim_size(0);
+ const int data_height = data_shape.dim_size(1);
+ const int data_width = data_shape.dim_size(2);
+ const int data_channels = data_shape.dim_size(3);
+ const ::tensorflow::TensorShape& warp_shape = warp.shape();
+ OP_REQUIRES(ctx,
+ ::tensorflow::TensorShapeUtils::IsMatrixOrHigher(warp_shape),
+ ::tensorflow::errors::InvalidArgument(
+ "warp should be at least a matrix, got shape ",
+ warp_shape.DebugString()));
+ OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims()-1) == 2,
+ ::tensorflow::errors::Unimplemented(
+ "Only bilinear interpolation is supported, warping "
+ "coordinates must be 2D; warp shape last entry should be "
+ "2, but shape vector is: ",
+ warp_shape.DebugString()));
+ const ::tensorflow::TensorShape& grad_output_shape = grad_output.shape();
+ ::tensorflow::TensorShape resampler_output_shape = warp.shape();
+ resampler_output_shape.set_dim(resampler_output_shape.dims() - 1,
+ data_channels);
+ OP_REQUIRES(ctx, grad_output_shape == resampler_output_shape,
+ ::tensorflow::errors::InvalidArgument(
+ "grad_output shape is not consistent with data and warp "
+ "shapes; it should be ",
+ resampler_output_shape.DebugString(), " but is ",
+ grad_output_shape.DebugString()))
+ const int num_sampling_points = warp.NumElements() / batch_size / 2;
+ ::tensorflow::Tensor* grad_data = nullptr;
+ ::tensorflow::Tensor* grad_warp = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, data.shape(), &grad_data));
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(1, warp.shape(), &grad_warp));
+ // Execute kernel only for nonempty output; otherwise Eigen crashes on GPU.
+ if (num_sampling_points > 0) {
+ functor::ResamplerGrad2DFunctor<Device, T>()(ctx,
+ ctx->eigen_device<Device>(),
+ data.flat<T>().data(),
+ warp.flat<T>().data(),
+ grad_output.flat<T>().data(),
+ grad_data->flat<T>().data(),
+ grad_warp->flat<T>().data(),
+ batch_size,
+ data_height,
+ data_width,
+ data_channels,
+ num_sampling_points);
+ }
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ResamplerGradOp);
+};
+
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ResamplerGrad") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("T"), \
+ ResamplerGradOp<CPUDevice, TYPE>);
+
+TF_CALL_half(REGISTER);
+TF_CALL_float(REGISTER);
+TF_CALL_double(REGISTER);
+#undef REGISTER
+
+#if GOOGLE_CUDA
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("ResamplerGrad") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<TYPE>("T"), \
+ ResamplerGradOp<GPUDevice, TYPE>)
+// Disable half and double precision since atomicAdds are not supported
+// TF_CALL_half(REGISTER);
+// TF_CALL_double(REGISTER);
+TF_CALL_float(REGISTER);
+
+#undef REGISTER
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops.h b/tensorflow/contrib/resampler/kernels/resampler_ops.h
new file mode 100644
index 0000000000..8258ecaf5d
--- /dev/null
+++ b/tensorflow/contrib/resampler/kernels/resampler_ops.h
@@ -0,0 +1,68 @@
+// Copyright 2017 The Sonnet Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_
+
+#if PLATFORM_WINDOWS
+#define __restrict__ __restrict
+#endif
+
+namespace tensorflow {
+class OpKernelContext;
+}
+
+namespace tensorflow {
+namespace functor {
+
+// Helper functor for the Resampler Op in 2D
+template <typename Device, typename T>
+struct Resampler2DFunctor{
+ void operator ()(::tensorflow::OpKernelContext* ctx,
+ const Device& d,
+ const T* __restrict__ data,
+ const T* __restrict__ warp,
+ T* __restrict__ output,
+ const int batch_size,
+ const int data_height,
+ const int data_width,
+ const int data_channels,
+ const int num_sampling_points);
+};
+
+
+// Helper functor for the Resampler Gradient Op in 2D
+template <typename Device, typename T>
+struct ResamplerGrad2DFunctor{
+ void operator ()(::tensorflow::OpKernelContext* ctx,
+ const Device& d,
+ const T* __restrict__ data,
+ const T* __restrict__ warp,
+ const T* __restrict__ grad_output,
+ T* __restrict__ grad_data,
+ T* __restrict__ grad_warp,
+ const int batch_size,
+ const int data_height,
+ const int data_width,
+ const int data_channels,
+ const int num_sampling_points);
+};
+
+
+} // namespace functor
+} // namespace tensorflow
+
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_
diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc b/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc
new file mode 100644
index 0000000000..636847a212
--- /dev/null
+++ b/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc
@@ -0,0 +1,310 @@
+// Copyright 2016 The Sonnet Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/contrib/resampler/kernels/resampler_ops.h"
+
+#include <stdio.h>
+#include <cmath>
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+
+using GPUDevice = Eigen::GpuDevice;
+
+namespace {
+
+#define GET_DATA_POINT(x, y) \
+ data[batch_id * data_batch_stride + \
+ data_channels * (y * data_width + x) + \
+ chan]
+
+template <typename T>
+__global__ void Resampler2DKernel(const T* __restrict__ data,
+ const T* __restrict__ warp,
+ T* __restrict__ output,
+ const int batch_size,
+ const int data_height,
+ const int data_width,
+ const int data_channels,
+ const int num_sampling_points) {
+ const int output_data_size = batch_size * num_sampling_points * data_channels;
+ CUDA_1D_KERNEL_LOOP(index, output_data_size) {
+ const int out_index = index;
+
+ // Get (idxSample, channel, point) from the index.
+ // Use this formula
+ // index = batch_id * num_sampling_points * num_chans +
+ // sample_id * num_chans + chan_id,
+ // with sample_id = [0, ... ,num_sampling_points)
+ const int data_batch_stride = data_height * data_width * data_channels;
+ const int warp_batch_stride = num_sampling_points * 2;
+ const int output_batch_stride = num_sampling_points * data_channels;
+
+ const int batch_id = index / output_batch_stride;
+ const int index_in_batch = index % output_batch_stride;
+ const int chan = index_in_batch % data_channels;
+ const int sample_id = index_in_batch / data_channels;
+
+ // Get coords of 2D point where data will be resampled
+ const T x = warp[batch_id * warp_batch_stride + sample_id * 2];
+ const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1];
+ const T zero = static_cast<T>(0.0);
+ const T one = static_cast<T>(1.0);
+ // The interpolation function:
+ // a) implicitly pads the input data with 0s (hence the unusual checks
+ // with {x,y} > -1)
+ // b) returns 0 when sampling outside the (padded) image.
+ // The effect is that the sampled signal smoothly goes to 0 outside
+ // the original input domain, rather than presenting a jump
+ // discontinuity at the image boundaries.
+ if (x > static_cast<T>(-1.0) &&
+ y > static_cast<T>(-1.0) &&
+ x < static_cast<T>(data_width) &&
+ y < static_cast<T>(data_height)) {
+ // Precompute floor (f) and ceil (c) values for x and y.
+ const int fx = std::floor(static_cast<float>(x));
+ const int fy = std::floor(static_cast<float>(y));
+ const int cx = fx + 1;
+ const int cy = fy + 1;
+ const T dx = static_cast<T>(cx) - x;
+ const T dy = static_cast<T>(cy) - y;
+
+ const T img_fxfy = (fx >= 0 && fy >= 0)
+ ? dx * dy * GET_DATA_POINT(fx, fy)
+ : zero;
+
+ const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
+ ? (one - dx) * (one - dy) * GET_DATA_POINT(cx, cy)
+ : zero;
+
+ const T img_fxcy = (fx >= 0 && cy <= data_height - 1)
+ ? dx * (one - dy) * GET_DATA_POINT(fx, cy)
+ : zero;
+
+ const T img_cxfy = (cx <= data_width - 1 && fy >= 0)
+ ? (one - dx) * dy * GET_DATA_POINT(cx, fy)
+ : zero;
+
+ output[out_index] = img_fxfy + img_cxcy + img_fxcy + img_cxfy;
+ } else {
+ output[out_index] = zero;
+ }
+ }
+}
+
+} // namespace
+
+namespace functor {
+
+template <typename T>
+struct Resampler2DFunctor<GPUDevice, T>{
+ void operator ()(::tensorflow::OpKernelContext* ctx,
+ const GPUDevice& d,
+ const T* __restrict__ data,
+ const T* __restrict__ warp,
+ T* __restrict__ output,
+ const int batch_size,
+ const int data_height,
+ const int data_width,
+ const int data_channels,
+ const int num_sampling_points) {
+ const int output_data_size = batch_size * num_sampling_points * data_channels;
+ ::tensorflow::CudaLaunchConfig config =
+ ::tensorflow::GetCudaLaunchConfig(output_data_size, d);
+ Resampler2DKernel<T>
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ data, warp, output, batch_size, data_height, data_width,
+ data_channels, num_sampling_points);
+ }
+};
+
+// TODO(fviola): gcudacc fails at compile time with Eigen::half.
+// template struct Resampler2DFunctor<GPUDevice, Eigen::half>;
+template struct Resampler2DFunctor<GPUDevice, float>;
+template struct Resampler2DFunctor<GPUDevice, double>;
+
+} // namespace functor
+
+namespace {
+
+#define UPDATE_GRAD_DATA_POINT(x, y, v) \
+ atomicAdd(grad_data + (batch_id * data_batch_stride + \
+ data_channels * (y * data_width + x) + \
+ chan), \
+ v)
+
+
+template <typename T>
+__global__ void ResamplerGrad2DKernel(const T* __restrict__ data,
+ const T* __restrict__ warp,
+ const T* __restrict__ grad_output,
+ T* __restrict__ grad_data,
+ T* __restrict__ grad_warp,
+ const int batch_size,
+ const int data_height,
+ const int data_width,
+ const int data_channels,
+ const int num_sampling_points) {
+ const int resampler_output_size = batch_size * num_sampling_points *
+ data_channels;
+ CUDA_1D_KERNEL_LOOP(index, resampler_output_size) {
+ const int out_index = index;
+
+ // Get (idxSample, channel, point) from the index.
+ // Use this formula
+ // index = batch_id * num_sampling_points * num_chans +
+ // sample_id * num_chans + chan_id,
+ // with sample_id = [0, ... ,num_sampling_points)
+ const int data_batch_stride = data_height * data_width * data_channels;
+ const int warp_batch_stride = num_sampling_points * 2;
+ const int output_batch_stride = num_sampling_points * data_channels;
+
+ const int batch_id = index / output_batch_stride;
+ const int index_in_batch = index % output_batch_stride;
+ const int chan = index_in_batch % data_channels;
+ const int sample_id = index_in_batch / data_channels;
+
+ // Get coords of 2D point where data will be resampled
+ const int warp_id_x = batch_id * warp_batch_stride + sample_id * 2;
+ const int warp_id_y = warp_id_x + 1;
+ const T x = warp[warp_id_x];
+ const T y = warp[warp_id_y];
+ const T zero = static_cast<T>(0.0);
+ const T one = static_cast<T>(1.0);
+
+ // Get grad output
+ const T grad_output_value = grad_output[out_index];
+ // The interpolation function whose gradient this kernel implements:
+ // a) implicitly pads the input data with 0s (hence the unusual checks
+ // with {x,y} > -1)
+ // b) returns 0 when sampling outside the (padded) image.
+ // The effect is that the sampled signal smoothly goes to 0 outside
+ // the original input domain, rather than presenting a jump
+ // discontinuity at the image boundaries.
+ if (x > static_cast<T>(-1.0) &&
+ y > static_cast<T>(-1.0) &&
+ x < static_cast<T>(data_width) &&
+ y < static_cast<T>(data_height)) {
+ // Precompute floor (f) and ceil (c) values for x and y.
+ const int fx = std::floor(static_cast<float>(x));
+ const int fy = std::floor(static_cast<float>(y));
+ const int cx = fx + 1;
+ const int cy = fy + 1;
+ const T dx = static_cast<T>(cx) - x;
+ const T dy = static_cast<T>(cy) - y;
+
+ const T img_fxfy = (fx >= 0 && fy >= 0)
+ ? GET_DATA_POINT(fx, fy)
+ : zero;
+
+ const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
+ ? GET_DATA_POINT(cx, cy)
+ : zero;
+
+ const T img_fxcy = (fx >= 0 && cy <= data_height - 1)
+ ? GET_DATA_POINT(fx, cy)
+ : zero;
+
+ const T img_cxfy = (cx <= data_width - 1 && fy >= 0)
+ ? GET_DATA_POINT(cx, fy)
+ : zero;
+
+ // Update partial gradients wrt relevant warp field entries
+ atomicAdd(grad_warp + warp_id_x,
+ grad_output_value * ((one - dy) * (img_cxcy - img_fxcy) +
+ dy * (img_cxfy - img_fxfy)));
+ atomicAdd(grad_warp + warp_id_y,
+ grad_output_value * ((one - dx) * (img_cxcy - img_cxfy) +
+ dx * (img_fxcy - img_fxfy)));
+
+ // Update partial gradients wrt sampled data
+ if (fx >= 0 && fy >= 0) {
+ UPDATE_GRAD_DATA_POINT(fx, fy, grad_output_value * dx * dy);
+ }
+ if (cx <= data_width - 1 && cy <= data_height - 1) {
+ UPDATE_GRAD_DATA_POINT(cx, cy,
+ grad_output_value * (one - dx) * (one - dy));
+ }
+ if (fx >= 0 && cy <= data_height - 1) {
+ UPDATE_GRAD_DATA_POINT(fx, cy, grad_output_value * dx * (one - dy));
+ }
+ if (cx <= data_width - 1 && fy >= 0) {
+ UPDATE_GRAD_DATA_POINT(cx, fy, grad_output_value * (one - dx) * dy);
+ }
+ }
+ }
+}
+
+#undef GET_DATA_POINT
+#undef UPDATE_GRAD_DATA_POINT
+
+} // namespace
+
+namespace functor {
+
+template <typename T>
+struct ResamplerGrad2DFunctor<GPUDevice, T>{
+ void operator ()(::tensorflow::OpKernelContext* ctx,
+ const GPUDevice& d,
+ const T* __restrict__ data,
+ const T* __restrict__ warp,
+ const T* __restrict__ grad_output,
+ T* __restrict__ grad_data,
+ T* __restrict__ grad_warp,
+ const int batch_size,
+ const int data_height,
+ const int data_width,
+ const int data_channels,
+ const int num_sampling_points) {
+ // Set gradients to 0, because the kernel incrementally updates the
+ // tensor entries by adding partial contributions.
+ const int grad_warp_size = batch_size * num_sampling_points * 2;
+ const int grad_data_size = batch_size * data_height * data_width *
+ data_channels;
+
+ ::tensorflow::CudaLaunchConfig config =
+ ::tensorflow::GetCudaLaunchConfig(grad_warp_size, d);
+ ::tensorflow::SetZero
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ grad_warp_size, grad_warp);
+
+ config = ::tensorflow::GetCudaLaunchConfig(grad_data_size, d);
+ ::tensorflow::SetZero
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ grad_data_size, grad_data);
+
+ const int resampler_output_size = batch_size * num_sampling_points *
+ data_channels;
+ config = ::tensorflow::GetCudaLaunchConfig(resampler_output_size, d);
+ ResamplerGrad2DKernel<T>
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ data, warp, grad_output, grad_data, grad_warp, batch_size,
+ data_height, data_width, data_channels, num_sampling_points);
+ }
+};
+
+template struct ResamplerGrad2DFunctor<GPUDevice, float>;
+
+} // namespace functor
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/resampler/ops/resampler_ops.cc b/tensorflow/contrib/resampler/ops/resampler_ops.cc
new file mode 100644
index 0000000000..5ab212032e
--- /dev/null
+++ b/tensorflow/contrib/resampler/ops/resampler_ops.cc
@@ -0,0 +1,59 @@
+// Copyright 2017 The Sonnet Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using ::tensorflow::shape_inference::InferenceContext;
+using ::tensorflow::shape_inference::ShapeHandle;
+
+REGISTER_OP("Resampler")
+ .Input("data: T")
+ .Input("warp: T")
+ .Output("output: T")
+ .Attr("T: {half, float, double}")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle data;
+ ShapeHandle warp;
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data));
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &warp));
+
+ ShapeHandle output; // will be warp[:-1] + [data[-1]]
+ TF_RETURN_IF_ERROR(c->Subshape(warp, 0, -1, &output));
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(output, c->Vector(c->Dim(data, -1)), &output));
+
+ c->set_output(0, output);
+ return ::tensorflow::Status::OK();
+ })
+ .Doc(R"doc(Resampler op.)doc");
+
+REGISTER_OP("ResamplerGrad")
+ .Input("data: T")
+ .Input("warp: T")
+ .Input("grad_output: T")
+ .Output("grad_data: T")
+ .Output("grad_warp: T")
+ .Attr("T: {half, float, double}")
+ .SetShapeFn([](InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ c->set_output(1, c->input(1));
+ return ::tensorflow::Status::OK();
+ })
+ .Doc(R"doc(Resampler Grad op.)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/resampler/python/__init__.py b/tensorflow/contrib/resampler/python/__init__.py
new file mode 100644
index 0000000000..c5ca3a623f
--- /dev/null
+++ b/tensorflow/contrib/resampler/python/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ops module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/resampler/python/ops/resampler_ops.py b/tensorflow/contrib/resampler/python/ops/resampler_ops.py
new file mode 100644
index 0000000000..355d15f0c7
--- /dev/null
+++ b/tensorflow/contrib/resampler/python/ops/resampler_ops.py
@@ -0,0 +1,69 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2017 The Sonnet Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tensorflow op performing differentiable resampling."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.resampler.ops import gen_resampler_ops
+from tensorflow.contrib.util import loader
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import resource_loader
+
+_resampler_so = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_resampler_ops.so"))
+
+
+def resampler(data, warp, name="resampler"):
+ """Resamples input data at user defined coordinates.
+
+ The resampler currently only supports bilinear interpolation of 2D data.
+
+ Args:
+ data: Tensor of shape `[batch_size, data_height, data_width,
+ data_num_channels]` containing 2D data that will be resampled.
+ warp: Tensor of minimum rank 2 containing the coordinates at which
+ resampling will be performed. Since only bilinear interpolation is
+ currently supported, the last dimension of the `warp` tensor must be 2.
+ name: Optional name of the op.
+
+ Returns:
+ Tensor of resampled values from `data`. The output tensor shape is
+ determined by the shape of the warp tensor. For example, if `data` is of
+ shape `[batch_size, data_height, data_width, data_num_channels]` and warp of
+ shape `[batch_size, dim_0, ... , dim_n, 2]` the output will be of shape
+ `[batch_size, dim_0, ... , dim_n, data_num_channels]`.
+
+ Raises:
+ ImportError: if the wrapper generated during compilation is not present when
+ the function is called.
+ """
+ with ops.name_scope(name, "resampler", [data, warp]):
+ data_tensor = ops.convert_to_tensor(data, name="data")
+ warp_tensor = ops.convert_to_tensor(warp, name="warp")
+ return gen_resampler_ops.resampler(data_tensor, warp_tensor)
+
+
+@ops.RegisterGradient("Resampler")
+def _resampler_grad(op, grad_output):
+ data, warp = op.inputs
+ grad_output_tensor = ops.convert_to_tensor(grad_output, name="grad_output")
+ return gen_resampler_ops.resampler_grad(data, warp, grad_output_tensor)
+
+
+ops.NotDifferentiable("ResamplerGrad")
diff --git a/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py
new file mode 100644
index 0000000000..6a4360150c
--- /dev/null
+++ b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py
@@ -0,0 +1,270 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2017 The Sonnet Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for contrib.resampler.python.ops.resampler_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.contrib import resampler
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+def _bilinearly_interpolate(data, x, y):
+ """Performs bilinenar interpolation of grid data at user defined coordinates.
+
+ This interpolation function:
+ a) implicitly pads the input data with 0s.
+ b) returns 0 when sampling outside the (padded) image.
+ The effect is that the sampled signal smoothly goes to 0 outside the original
+ input domain, rather than producing a jump discontinuity at the image
+ boundaries.
+
+ Args:
+ data: numpy array of shape `[data_height, data_width]` containing data
+ samples assumed to be defined at the corresponding pixel coordinates.
+ x: numpy array of shape `[warp_height, warp_width]` containing x coordinates
+ at which interpolation will be performed.
+ y: numpy array of shape `[warp_height, warp_width]` containing y coordinates
+ at which interpolation will be performed.
+
+ Returns:
+ Numpy array of shape `[warp_height, warp_width]` containing interpolated
+ values.
+ """
+ shape = x.shape
+ x = np.asarray(x) + 1
+ y = np.asarray(y) + 1
+ data = np.lib.pad(data, 1, "constant", constant_values=0)
+
+ x_0 = np.floor(x).astype(int)
+ x_1 = x_0 + 1
+ y_0 = np.floor(y).astype(int)
+ y_1 = y_0 + 1
+
+ x_0 = np.clip(x_0, 0, data.shape[1] - 1)
+ x_1 = np.clip(x_1, 0, data.shape[1] - 1)
+ y_0 = np.clip(y_0, 0, data.shape[0] - 1)
+ y_1 = np.clip(y_1, 0, data.shape[0] - 1)
+
+ i_a = data[y_0, x_0]
+ i_b = data[y_1, x_0]
+ i_c = data[y_0, x_1]
+ i_d = data[y_1, x_1]
+
+ w_a = (x_1 - x) * (y_1 - y)
+ w_b = (x_1 - x) * (y - y_0)
+ w_c = (x - x_0) * (y_1 - y)
+ w_d = (x - x_0) * (y - y_0)
+
+ samples = (w_a * i_a + w_b * i_b + w_c * i_c + w_d * i_d)
+ samples.reshape(shape)
+
+ return samples
+
+
+def _make_warp(batch_size, warp_height, warp_width, dtype):
+ """Creates batch of warping coordinates."""
+ x, y = np.meshgrid(np.linspace(0, warp_width - 1, warp_width),
+ np.linspace(0, warp_height - 1, warp_height))
+ warp = np.concatenate((x.reshape([warp_height, warp_width, 1]),
+ y.reshape([warp_height, warp_width, 1])), 2)
+ warp = np.tile(warp.reshape([1, warp_height, warp_width, 2]),
+ [batch_size, 1, 1, 1])
+ warp += np.random.randn(*warp.shape)
+ return warp.astype(dtype)
+
+
+class ResamplerTest(test.TestCase):
+
+ def test_op_forward_pass_gpu_float32(self):
+ self._test_op_forward_pass(True, dtypes.float32, 1e-4)
+
+ def test_op_forward_pass_gpu_float64(self):
+ self._test_op_forward_pass(True, dtypes.float64, 1e-5)
+
+ def test_op_forward_pass_cpu_float16(self):
+ self._test_op_forward_pass(False, dtypes.float16, 1e-2)
+
+ def test_op_forward_pass_cpu_float32(self):
+ self._test_op_forward_pass(False, dtypes.float32, 1e-4)
+
+ def test_op_forward_pass_cpu_float64(self):
+ self._test_op_forward_pass(False, dtypes.float64, 1e-5)
+
+ def test_op_backward_pass_gpu_float32(self):
+ self._test_op_backward_pass(True, dtypes.float32, 1e-3)
+
+ def test_op_backward_pass_cpu_float16(self):
+ self._test_op_backward_pass(False, dtypes.float16, 1e-3)
+
+ def test_op_backward_pass_cpu_float32(self):
+ self._test_op_backward_pass(False, dtypes.float32, 1e-4)
+
+ def test_op_backward_pass_cpu_float64(self):
+ self._test_op_backward_pass(False, dtypes.float64, 1e-6)
+
+ def _test_op_forward_pass(self, on_gpu, dtype, tol):
+ np.random.seed(0)
+ data_width = 7
+ data_height = 9
+ data_channels = 5
+ warp_width = 4
+ warp_height = 8
+ batch_size = 10
+
+ warp = _make_warp(batch_size, warp_height, warp_width, dtype.as_numpy_dtype)
+ data_shape = (batch_size, data_height, data_width, data_channels)
+ data = np.random.rand(*data_shape).astype(dtype.as_numpy_dtype)
+
+ with self.test_session(use_gpu=on_gpu, force_gpu=False) as sess:
+ data_ph = array_ops.placeholder(dtype, shape=(None,) + data.shape[1:])
+ warp_ph = array_ops.placeholder(dtype, shape=(None,) + warp.shape[1:])
+ outputs = resampler.resampler(data=data_ph, warp=warp_ph)
+ self.assertEqual(outputs.get_shape().as_list(),
+ [None, warp_height, warp_width, data_channels])
+ out = sess.run(outputs, feed_dict={data_ph: data, warp_ph: warp})
+
+ # Generate reference output via bilinear interpolation in numpy
+ reference_output = np.zeros_like(out)
+ for batch in xrange(batch_size):
+ for c in xrange(data_channels):
+ reference_output[batch, :, :, c] = _bilinearly_interpolate(
+ data[batch, :, :, c],
+ warp[batch, :, :, 0],
+ warp[batch, :, :, 1])
+
+ self.assertAllClose(out, reference_output, rtol=tol, atol=tol)
+
+ def _test_op_backward_pass(self, on_gpu, dtype, tol):
+ np.random.seed(13)
+ data_width = 5
+ data_height = 4
+ data_channels = 3
+ warp_width = 2
+ warp_height = 6
+ batch_size = 10
+
+ warp = _make_warp(batch_size, warp_height, warp_width, dtype.as_numpy_dtype)
+ data_shape = (batch_size, data_height, data_width, data_channels)
+ data = np.random.rand(*data_shape).astype(dtype.as_numpy_dtype)
+
+ with self.test_session(use_gpu=on_gpu, force_gpu=False):
+ data_tensor = constant_op.constant(data)
+ warp_tensor = constant_op.constant(warp)
+ output_tensor = resampler.resampler(data=data_tensor, warp=warp_tensor)
+
+ grads = test.compute_gradient([data_tensor, warp_tensor], [
+ data_tensor.get_shape().as_list(),
+ warp_tensor.get_shape().as_list()
+ ], output_tensor, output_tensor.get_shape().as_list(), [data, warp])
+
+ if not on_gpu:
+ # On CPU we perform numerical differentiation at the best available
+ # precision, and compare against that. This is necessary for test to
+ # pass for float16.
+ data_tensor_64 = constant_op.constant(data, dtype=dtypes.float64)
+ warp_tensor_64 = constant_op.constant(warp, dtype=dtypes.float64)
+ output_tensor_64 = resampler.resampler(data=data_tensor_64,
+ warp=warp_tensor_64)
+ grads_64 = test.compute_gradient([data_tensor_64, warp_tensor_64], [
+ data_tensor.get_shape().as_list(),
+ warp_tensor.get_shape().as_list()
+ ], output_tensor_64, output_tensor.get_shape().as_list(), [data, warp])
+
+ for g, g_64 in zip(grads, grads_64):
+ self.assertLess(np.fabs(g[0] - g_64[1]).max(), tol)
+
+ else:
+ for g in grads:
+ self.assertLess(np.fabs(g[0] - g[1]).max(), tol)
+
+ def test_op_errors(self):
+ data_width = 7
+ data_height = 9
+ data_depth = 3
+ data_channels = 5
+ warp_width = 4
+ warp_height = 8
+ batch_size = 10
+
+ # Input data shape is not defined over a 2D grid, i.e. its shape is not like
+ # (batch_size, data_height, data_width, data_channels).
+ with self.test_session() as sess:
+ data_shape = (batch_size, data_height, data_width, data_depth,
+ data_channels)
+ data = np.zeros(data_shape)
+ warp_shape = (batch_size, warp_height, warp_width, 2)
+ warp = np.zeros(warp_shape)
+ outputs = resampler.resampler(constant_op.constant(data),
+ constant_op.constant(warp))
+
+ with self.assertRaisesRegexp(errors_impl.UnimplementedError,
+ "Only bilinear interpolation is currently "
+ "supported."):
+ sess.run(outputs)
+
+ # Warp tensor must be at least a matrix, with shape [batch_size, 2].
+ with self.test_session() as sess:
+ data_shape = (batch_size, data_height, data_width, data_channels)
+ data = np.zeros(data_shape)
+ warp_shape = (batch_size,)
+ warp = np.zeros(warp_shape)
+ outputs = resampler.resampler(constant_op.constant(data),
+ constant_op.constant(warp))
+
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "warp should be at least a matrix"):
+ sess.run(outputs)
+
+ # The batch size of the data and warp tensors must be the same.
+ with self.test_session() as sess:
+ data_shape = (batch_size, data_height, data_width, data_channels)
+ data = np.zeros(data_shape)
+ warp_shape = (batch_size+1, warp_height, warp_width, 2)
+ warp = np.zeros(warp_shape)
+ outputs = resampler.resampler(constant_op.constant(data),
+ constant_op.constant(warp))
+
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "Batch size of data and warp tensor"):
+ sess.run(outputs)
+
+ # The warp tensor must contain 2D coordinates, i.e. its shape last dimension
+ # must be 2.
+ with self.test_session() as sess:
+ data_shape = (batch_size, data_height, data_width, data_channels)
+ data = np.zeros(data_shape)
+ warp_shape = (batch_size, warp_height, warp_width, 3)
+ warp = np.zeros(warp_shape)
+ outputs = resampler.resampler(constant_op.constant(data),
+ constant_op.constant(warp))
+
+ with self.assertRaisesRegexp(errors_impl.UnimplementedError,
+ "Only bilinear interpolation is supported, "
+ "warping"):
+ sess.run(outputs)
+
+
+if __name__ == "__main__":
+ test.main()