diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-26 05:47:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-26 05:51:07 -0700 |
commit | a80c8b583fa3d619b358a91aefd05069227b8967 (patch) | |
tree | f7270dedb338cb1ef5306a63d963416a52483aa1 /tensorflow/contrib/resampler | |
parent | 7a06e0af350e3e61bbf0ebee66d79ab15808c575 (diff) |
Move resampler from sonnet to contrib.
PiperOrigin-RevId: 160134565
Diffstat (limited to 'tensorflow/contrib/resampler')
-rw-r--r-- | tensorflow/contrib/resampler/BUILD | 92 | ||||
-rw-r--r-- | tensorflow/contrib/resampler/__init__.py | 26 | ||||
-rw-r--r-- | tensorflow/contrib/resampler/kernels/resampler_ops.cc | 465 | ||||
-rw-r--r-- | tensorflow/contrib/resampler/kernels/resampler_ops.h | 68 | ||||
-rw-r--r-- | tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc | 310 | ||||
-rw-r--r-- | tensorflow/contrib/resampler/ops/resampler_ops.cc | 59 | ||||
-rw-r--r-- | tensorflow/contrib/resampler/python/__init__.py | 19 | ||||
-rw-r--r-- | tensorflow/contrib/resampler/python/ops/resampler_ops.py | 69 | ||||
-rw-r--r-- | tensorflow/contrib/resampler/python/ops/resampler_ops_test.py | 270 |
9 files changed, 1378 insertions, 0 deletions
diff --git a/tensorflow/contrib/resampler/BUILD b/tensorflow/contrib/resampler/BUILD new file mode 100644 index 0000000000..1b9efd1ecd --- /dev/null +++ b/tensorflow/contrib/resampler/BUILD @@ -0,0 +1,92 @@ +licenses(["notice"]) # Apache 2.0 License + +exports_files(["LICENSE"]) + +package(default_visibility = ["//visibility:public"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_custom_op_py_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", +) +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +tf_custom_op_py_library( + name = "resampler_py", + srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + dso = [":python/ops/_resampler_ops.so"], + kernels = [ + ":resampler_ops_kernels", + ":resampler_ops_op_lib", + ], + visibility = ["//visibility:public"], + deps = [ + ":resampler_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + +tf_kernel_library( + name = "resampler_ops_kernels", + prefix = "resampler_ops", + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +tf_custom_op_library( + name = "python/ops/_resampler_ops.so", + srcs = [ + "kernels/resampler_ops.cc", + "kernels/resampler_ops.h", + "ops/resampler_ops.cc", + ], + gpu_srcs = [ + "kernels/resampler_ops_gpu.cu.cc", + "kernels/resampler_ops.h", + ], +) + +tf_gen_op_libs( + op_lib_names = [ + "resampler_ops", + ], +) + +tf_gen_op_wrapper_py( + name = "resampler_ops", + deps = [":resampler_ops_op_lib"], +) + +cuda_py_test( + name = "resampler_ops_test", + size = "small", + srcs = ["python/ops/resampler_ops_test.py"], + additional_deps = [ + ":resampler_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:array_ops", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + 
exclude = [ + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/resampler/__init__.py b/tensorflow/contrib/resampler/__init__.py new file mode 100644 index 0000000000..3e04e5762d --- /dev/null +++ b/tensorflow/contrib/resampler/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops and modules related to resampler.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +# pylint: disable=wildcard-import +from tensorflow.contrib.resampler.python.ops.resampler_ops import * +from tensorflow.python.util.all_util import remove_undocumented + +remove_undocumented(__name__, ["resampler"]) diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops.cc b/tensorflow/contrib/resampler/kernels/resampler_ops.cc new file mode 100644 index 0000000000..afc8bcd446 --- /dev/null +++ b/tensorflow/contrib/resampler/kernels/resampler_ops.cc @@ -0,0 +1,465 @@ +// Copyright 2017 The Sonnet Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#define EIGEN_USE_THREADS + +#include "tensorflow/contrib/resampler/kernels/resampler_ops.h" + +#include <algorithm> +#include <cmath> +#include <memory> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +namespace functor { + +template <typename T> +struct Resampler2DFunctor<CPUDevice, T>{ + void operator ()(::tensorflow::OpKernelContext* ctx, + const CPUDevice& d, + const T* __restrict__ data, + const T* __restrict__ warp, + T* __restrict__ output, + const int batch_size, + const int data_height, + const int data_width, + const int data_channels, + const int num_sampling_points){ + const int warp_batch_stride = num_sampling_points * 2; + const int data_batch_stride = data_height * data_width * data_channels; + const int output_batch_stride = num_sampling_points * data_channels; + const T zero = static_cast<T>(0.0); + const T one = static_cast<T>(1.0); + + auto resample_batches = [&](const int start, const int limit) { + for (int batch_id = start; batch_id < limit; ++batch_id) { + // Utility lambda to access data point and set output values. 
+ // The functions take care of performing the relevant pointer + // arithmetics abstracting away the low level details in the + // main loop over samples. Note that data is stored in NHWC format. + auto set_output = [&](const int sample_id, + const int channel, + const T value) { + output[batch_id * output_batch_stride + + sample_id * data_channels + + channel] = value; + }; + + auto get_data_point = [&](const int x, + const int y, + const int chan) { + const bool point_is_in_range = + (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1); + return point_is_in_range + ? data[batch_id * data_batch_stride + + data_channels * (y * data_width + x) + + chan] + : zero; + }; + + for (int sample_id = 0; sample_id < num_sampling_points; ++sample_id) { + const T x = warp[batch_id * warp_batch_stride + sample_id * 2]; + const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1]; + // The interpolation function: + // a) implicitly pads the input data with 0s (hence the unusual checks + // with {x,y} > -1) + // b) returns 0 when sampling outside the (padded) image. + // The effect is that the sampled signal smoothly goes to 0 outside + // the original input domain, rather than presenting a jump + // discontinuity at the image boundaries. + if (x > static_cast<T>(-1.0) && + y > static_cast<T>(-1.0) && + x < static_cast<T>(data_width) && + y < static_cast<T>(data_height)) { + // Precompute floor (f) and ceil (c) values for x and y. 
+ const int fx = std::floor(static_cast<float>(x)); + const int fy = std::floor(static_cast<float>(y)); + const int cx = fx + 1; + const int cy = fy + 1; + const T dx = static_cast<T>(cx) - x; + const T dy = static_cast<T>(cy) - y; + + for (int chan = 0; chan < data_channels; ++chan) { + const T img_fxfy = dx * dy * get_data_point(fx, fy, chan); + const T img_cxcy = (one - dx) * (one - dy) * + get_data_point(cx, cy, chan); + const T img_fxcy = dx * (one - dy) * + get_data_point(fx, cy, chan); + const T img_cxfy = (one - dx) * dy * + get_data_point(cx, fy, chan); + set_output(sample_id, chan, + img_fxfy + img_cxcy + img_fxcy + img_cxfy); + } + } else { + for (int chan = 0; chan < data_channels; ++chan) { + set_output(sample_id, chan, zero); + } + } + } + } + }; + // Rough estimate of work for each batch entry. + // From third_party/tensorflow/core/util/work_sharder.cc we gather that an + // estimate of the cost of each work unit is needed to correctly shard the + // workload. Shard assumes each cost unit is 1ns, minimum cost per shard + // being 10us. + const int64 cost = static_cast<int64>(num_sampling_points) * + data_channels * 1000; + auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); + ::tensorflow::Shard(worker_threads.num_threads, worker_threads.workers, + batch_size, cost, resample_batches); + } +}; + +} // namespace functor + +template <typename Device, typename T> +class ResamplerOp : public ::tensorflow::OpKernel { + public: + explicit ResamplerOp(::tensorflow::OpKernelConstruction* context) : + ::tensorflow::OpKernel(context) {} + + void Compute(::tensorflow::OpKernelContext* ctx) override { + const ::tensorflow::Tensor& data = ctx->input(0); + const ::tensorflow::Tensor& warp = ctx->input(1); + + const ::tensorflow::TensorShape& data_shape = data.shape(); + OP_REQUIRES(ctx, data_shape.dims() == 4, + ::tensorflow::errors::Unimplemented( + "Only bilinear interpolation is currently supported. 
The " + "input data shape must be [batch_size, data_height, " + "data_width, data_channels], but is: ", + data_shape.DebugString())); + const ::tensorflow::TensorShape& warp_shape = warp.shape(); + OP_REQUIRES(ctx, + ::tensorflow::TensorShapeUtils::IsMatrixOrHigher(warp_shape), + ::tensorflow::errors::InvalidArgument( + "warp should be at least a matrix, got shape ", + warp_shape.DebugString())); + OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims()-1) == 2, + ::tensorflow::errors::Unimplemented( + "Only bilinear interpolation is supported, warping " + "coordinates must be 2D; warp shape last entry should be " + "2, but shape vector is: ", warp_shape.DebugString())); + OP_REQUIRES(ctx, data_shape.dim_size(0) == warp_shape.dim_size(0), + ::tensorflow::errors::InvalidArgument( + "Batch size of data and warp tensor must be the same, but " + "input shapes are: ", data_shape.DebugString(), ", ", + warp_shape.DebugString())); + const int batch_size = data_shape.dim_size(0); + const int data_height = data_shape.dim_size(1); + const int data_width = data_shape.dim_size(2); + const int data_channels = data_shape.dim_size(3); + ::tensorflow::TensorShape output_shape = warp.shape(); + output_shape.set_dim(output_shape.dims() - 1, data_channels); + const int num_sampling_points = warp.NumElements() / batch_size / 2; + ::tensorflow::Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output)); + + // Execute kernel only for nonempty output; otherwise Eigen crashes on GPU. 
+ if (num_sampling_points > 0) { + functor::Resampler2DFunctor<Device, T>()(ctx, + ctx->eigen_device<Device>(), + data.flat<T>().data(), + warp.flat<T>().data(), + output->flat<T>().data(), + batch_size, + data_height, + data_width, + data_channels, + num_sampling_points); + } + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ResamplerOp); +}; + + +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Resampler") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<TYPE>("T"), \ + ResamplerOp<CPUDevice, TYPE>); + +TF_CALL_half(REGISTER); +TF_CALL_float(REGISTER); +TF_CALL_double(REGISTER); +#undef REGISTER + +#if GOOGLE_CUDA +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("Resampler") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<TYPE>("T"), \ + ResamplerOp<GPUDevice, TYPE>) +TF_CALL_float(REGISTER); +TF_CALL_double(REGISTER); +#undef REGISTER +#endif // GOOGLE_CUDA + + +namespace functor { + +template <typename T> +struct ResamplerGrad2DFunctor<CPUDevice, T>{ + void operator ()(::tensorflow::OpKernelContext* ctx, + const CPUDevice& d, + const T* __restrict__ data, + const T* __restrict__ warp, + const T* __restrict__ grad_output, + T* __restrict__ grad_data, + T* __restrict__ grad_warp, + const int batch_size, + const int data_height, + const int data_width, + const int data_channels, + const int num_sampling_points){ + // Set gradients to 0, because the kernel incrementally updates the + // tensor entries by adding partial contributions. 
+ const int resampler_output_size = batch_size * num_sampling_points * + data_channels; + const int grad_warp_size = resampler_output_size / data_channels * 2; + const int grad_data_size = data_height * data_width * data_channels * + batch_size; + memset(grad_data, 0, sizeof(T) * grad_data_size); + memset(grad_warp, 0, sizeof(T) * grad_warp_size); + + const auto&& data_batch_stride = data_height * data_width * data_channels; + const auto&& warp_batch_stride = num_sampling_points * 2; + const int output_batch_stride = num_sampling_points * data_channels; + const T zero = static_cast<T>(0.0); + const T one = static_cast<T>(1.0); + + auto update_grads_for_batches = [&](const int start, const int limit) { + for (int batch_id = start; batch_id < limit; ++batch_id) { + // Utility lambdas to access data and update gradient tensors. + // The functions take care of performing the relevant pointer + // arithmetics abstracting away the low level details in the + // main loop over samples. Note that data is stored in NHWC format. + auto get_data_point = [&](const int x, + const int y, + const int chan) { + const bool point_is_in_range = + (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1); + return point_is_in_range + ? 
data[batch_id * data_batch_stride + + data_channels * (y * data_width + x) + + chan] + : zero; + }; + + auto update_grad_data = [&](const int x, const int y, const int chan, + const T value) { + const bool point_is_in_range = + (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1); + if (point_is_in_range){ + grad_data[batch_id * data_batch_stride + + data_channels * (y * data_width + x) + + chan] += value; + } + }; + + auto update_grad_warp = [&](const int sample_id, + const int channel, + const T value) { + grad_warp[batch_id * warp_batch_stride + + sample_id * 2 + + channel] += value; + }; + + for (int sample_id = 0; sample_id < num_sampling_points; ++sample_id) { + const T x = warp[batch_id * warp_batch_stride + sample_id * 2]; + const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1]; + // The interpolation function whose gradient this function implements: + // a) implicitly pads the input data with 0s (hence the unusual checks + // with {x,y} > -1) + // b) returns 0 when sampling outside the (padded) image. + // The effect is that the sampled signal smoothly goes to 0 outside + // the original input domain, rather than presenting a jump + // discontinuity at the image boundaries. + if (x > static_cast<T>(-1.0) && + y > static_cast<T>(-1.0) && + x < static_cast<T>(data_width) && + y < static_cast<T>(data_height)) { + // Precompute floor (f) and ceil (c) values for x and y. 
+ const int fx = std::floor(static_cast<float>(x)); + const int fy = std::floor(static_cast<float>(y)); + const int cx = fx + 1; + const int cy = fy + 1; + const T dx = static_cast<T>(cx) - x; + const T dy = static_cast<T>(cy) - y; + + for (int chan = 0; chan < data_channels; ++chan) { + const T grad_output_value = + grad_output[batch_id * output_batch_stride + + sample_id * data_channels + + chan]; + const T img_fxfy = get_data_point(fx, fy, chan); + const T img_cxcy = get_data_point(cx, cy, chan); + const T img_fxcy = get_data_point(fx, cy, chan); + const T img_cxfy = get_data_point(cx, fy, chan); + + // Update partial gradients wrt relevant warp field entries + update_grad_warp(sample_id, 0, + grad_output_value * + ((one - dy) * (img_cxcy - img_fxcy) + + dy * (img_cxfy - img_fxfy))); + + update_grad_warp(sample_id, 1, + grad_output_value * + ((one - dx) * (img_cxcy - img_cxfy) + + dx * (img_fxcy - img_fxfy))); + + // Update partial gradients wrt sampled data + update_grad_data(fx, fy, chan, + grad_output_value * dx * dy); + update_grad_data(cx, cy, chan, + grad_output_value * (one - dx) * (one - dy)); + update_grad_data(fx, cy, chan, + grad_output_value * dx * (one - dy)); + update_grad_data(cx, fy, chan, + grad_output_value * (one - dx) * dy); + } + } + } + } + }; + // Rough estimate of work for each batch entry. + // From third_party/tensorflow/core/util/work_sharder.cc we gather that an + // estimate of the cost of each work unit is needed to correctly shard the + // workload. Shard assumes each cost unit is 1ns, minimum cost per shard + // being 10us. + // TODO(fviola): Check out if there is a better way of doing this. 
+ auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); + const int64 cost = static_cast<int64>(num_sampling_points) * + data_channels * 1000; + ::tensorflow::Shard(worker_threads.num_threads, worker_threads.workers, + batch_size, cost, update_grads_for_batches); + } +}; + +} // namespace functor + + +template <typename Device, typename T> +class ResamplerGradOp : public ::tensorflow::OpKernel { + public: + explicit ResamplerGradOp(::tensorflow::OpKernelConstruction* context) : + ::tensorflow::OpKernel(context) {} + + void Compute(::tensorflow::OpKernelContext* ctx) override { + const ::tensorflow::Tensor& data = ctx->input(0); + const ::tensorflow::Tensor& warp = ctx->input(1); + const ::tensorflow::Tensor& grad_output = ctx->input(2); + + const ::tensorflow::TensorShape& data_shape = data.shape(); + OP_REQUIRES(ctx, data_shape.dims() == 4, + ::tensorflow::errors::Unimplemented( + "Only bilinear interpolation is supported, the input data " + "tensor must be a batch of 2d data; data shape should have " + "4 entries corresponding to [batch_size, data_height, " + "data_width, data_channels], but is: ", + data_shape.DebugString())); + const int batch_size = data_shape.dim_size(0); + const int data_height = data_shape.dim_size(1); + const int data_width = data_shape.dim_size(2); + const int data_channels = data_shape.dim_size(3); + const ::tensorflow::TensorShape& warp_shape = warp.shape(); + OP_REQUIRES(ctx, + ::tensorflow::TensorShapeUtils::IsMatrixOrHigher(warp_shape), + ::tensorflow::errors::InvalidArgument( + "warp should be at least a matrix, got shape ", + warp_shape.DebugString())); + OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims()-1) == 2, + ::tensorflow::errors::Unimplemented( + "Only bilinear interpolation is supported, warping " + "coordinates must be 2D; warp shape last entry should be " + "2, but shape vector is: ", + warp_shape.DebugString())); + const ::tensorflow::TensorShape& grad_output_shape = grad_output.shape(); + 
::tensorflow::TensorShape resampler_output_shape = warp.shape(); + resampler_output_shape.set_dim(resampler_output_shape.dims() - 1, + data_channels); + OP_REQUIRES(ctx, grad_output_shape == resampler_output_shape, + ::tensorflow::errors::InvalidArgument( + "grad_output shape is not consistent with data and warp " + "shapes; it should be ", + resampler_output_shape.DebugString(), " but is ", + grad_output_shape.DebugString())) + const int num_sampling_points = warp.NumElements() / batch_size / 2; + ::tensorflow::Tensor* grad_data = nullptr; + ::tensorflow::Tensor* grad_warp = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, data.shape(), &grad_data)); + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, warp.shape(), &grad_warp)); + // Execute kernel only for nonempty output; otherwise Eigen crashes on GPU. + if (num_sampling_points > 0) { + functor::ResamplerGrad2DFunctor<Device, T>()(ctx, + ctx->eigen_device<Device>(), + data.flat<T>().data(), + warp.flat<T>().data(), + grad_output.flat<T>().data(), + grad_data->flat<T>().data(), + grad_warp->flat<T>().data(), + batch_size, + data_height, + data_width, + data_channels, + num_sampling_points); + } + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ResamplerGradOp); +}; + +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("ResamplerGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<TYPE>("T"), \ + ResamplerGradOp<CPUDevice, TYPE>); + +TF_CALL_half(REGISTER); +TF_CALL_float(REGISTER); +TF_CALL_double(REGISTER); +#undef REGISTER + +#if GOOGLE_CUDA +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("ResamplerGrad") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<TYPE>("T"), \ + ResamplerGradOp<GPUDevice, TYPE>) +// Disable half and double precision since atomicAdds are not supported +// TF_CALL_half(REGISTER); +// TF_CALL_double(REGISTER); +TF_CALL_float(REGISTER); + +#undef REGISTER +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops.h 
b/tensorflow/contrib/resampler/kernels/resampler_ops.h new file mode 100644 index 0000000000..8258ecaf5d --- /dev/null +++ b/tensorflow/contrib/resampler/kernels/resampler_ops.h @@ -0,0 +1,68 @@
// Copyright 2017 The Sonnet Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_

// On Windows builds, map the GCC/Clang spelling used below onto __restrict.
#if PLATFORM_WINDOWS
#define __restrict__ __restrict
#endif

// Forward declaration: the functors only take a pointer to the context, so the
// full OpKernelContext definition is not needed in this header.
namespace tensorflow {
class OpKernelContext;
}

namespace tensorflow {
namespace functor {

// Helper functor for the Resampler Op in 2D.
//
// Declared here and specialised per device. All buffers are raw flat
// pointers; the GPU kernels in this package index them as
//   data:   [batch_size, data_height, data_width, data_channels]
//   warp:   [batch_size, num_sampling_points, 2]  (x first, then y)
//   output: [batch_size, num_sampling_points, data_channels]
// all in row-major order.
template <typename Device, typename T>
struct Resampler2DFunctor{
  void operator ()(::tensorflow::OpKernelContext* ctx,
                   const Device& d,
                   const T* __restrict__ data,
                   const T* __restrict__ warp,
                   T* __restrict__ output,
                   const int batch_size,
                   const int data_height,
                   const int data_width,
                   const int data_channels,
                   const int num_sampling_points);
};


// Helper functor for the Resampler Gradient Op in 2D.
//
// Takes the forward inputs (`data`, `warp`) and the gradient flowing into the
// resampler output (`grad_output`, laid out like the forward output), and
// writes the gradients with respect to `data` and `warp` into `grad_data` and
// `grad_warp`. The op kernel allocates those two outputs with the shapes of
// `data` and `warp` respectively.
template <typename Device, typename T>
struct ResamplerGrad2DFunctor{
  void operator ()(::tensorflow::OpKernelContext* ctx,
                   const Device& d,
                   const T* __restrict__ data,
                   const T* __restrict__ warp,
                   const T* __restrict__ grad_output,
                   T* __restrict__ grad_data,
                   T* __restrict__ grad_warp,
                   const int batch_size,
                   const int data_height,
                   const int data_width,
                   const int data_channels,
                   const int num_sampling_points);
};


}  // namespace functor
}  // namespace tensorflow


#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_
diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc b/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc new file mode 100644 index 0000000000..636847a212 --- /dev/null +++ b/tensorflow/contrib/resampler/kernels/resampler_ops_gpu.cu.cc @@ -0,0 +1,310 @@
// Copyright 2016 The Sonnet Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/contrib/resampler/kernels/resampler_ops.h"

#include <stdio.h>
#include <cmath>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

using GPUDevice = Eigen::GpuDevice;

namespace {

// Reads data[batch_id, y, x, chan] from the flat `data` buffer. Relies on
// `batch_id`, `data_batch_stride`, `data_channels`, `data_width` and `chan`
// being in scope at the point of use. Shared by both kernels below and
// #undef'd after the gradient kernel.
#define GET_DATA_POINT(x, y)                                                 \
  data[batch_id * data_batch_stride +                                        \
       data_channels * (y * data_width + x) +                                \
       chan]

// Forward kernel: bilinearly interpolates `data` at the coordinates in `warp`.
//
// Each loop index handles one output element, i.e. one
// (batch_id, sample_id, chan) triple, and writes exactly one entry of
// `output`.
template <typename T>
__global__ void Resampler2DKernel(const T* __restrict__ data,
                                  const T* __restrict__ warp,
                                  T* __restrict__ output,
                                  const int batch_size,
                                  const int data_height,
                                  const int data_width,
                                  const int data_channels,
                                  const int num_sampling_points) {
  const int output_data_size = batch_size * num_sampling_points * data_channels;
  CUDA_1D_KERNEL_LOOP(index, output_data_size) {
    const int out_index = index;

    // Decompose the flat output index into (batch_id, sample_id, chan).
    // Use this formula
    //   index = batch_id * num_sampling_points * num_chans +
    //           sample_id * num_chans + chan_id,
    // with sample_id = [0, ... ,num_sampling_points)
    const int data_batch_stride = data_height * data_width * data_channels;
    const int warp_batch_stride = num_sampling_points * 2;
    const int output_batch_stride = num_sampling_points * data_channels;

    const int batch_id = index / output_batch_stride;
    const int index_in_batch = index % output_batch_stride;
    const int chan = index_in_batch % data_channels;
    const int sample_id = index_in_batch / data_channels;

    // Get coords of 2D point where data will be resampled
    // (warp stores x at the even offset and y at the following odd offset).
    const T x = warp[batch_id * warp_batch_stride + sample_id * 2];
    const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1];
    const T zero = static_cast<T>(0.0);
    const T one = static_cast<T>(1.0);
    // The interpolation function:
    // a) implicitly pads the input data with 0s (hence the unusual checks
    // with {x,y} > -1)
    // b) returns 0 when sampling outside the (padded) image.
    // The effect is that the sampled signal smoothly goes to 0 outside
    // the original input domain, rather than presenting a jump
    // discontinuity at the image boundaries.
    if (x > static_cast<T>(-1.0) &&
        y > static_cast<T>(-1.0) &&
        x < static_cast<T>(data_width) &&
        y < static_cast<T>(data_height)) {
      // Precompute floor (f) and ceil (c) values for x and y.
      // Note that the coordinate is converted to float before flooring,
      // whatever T is.
      const int fx = std::floor(static_cast<float>(x));
      const int fy = std::floor(static_cast<float>(y));
      const int cx = fx + 1;
      const int cy = fy + 1;
      // dx/dy are the distances to the *ceil* neighbour, so they are the
      // weights of the *floor* neighbour along each axis.
      const T dx = static_cast<T>(cx) - x;
      const T dy = static_cast<T>(cy) - y;

      // Each of the four neighbours contributes only if it lies inside the
      // image; neighbours falling in the implicit zero padding contribute 0.
      const T img_fxfy = (fx >= 0 && fy >= 0)
          ? dx * dy * GET_DATA_POINT(fx, fy)
          : zero;

      const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
          ? (one - dx) * (one - dy) * GET_DATA_POINT(cx, cy)
          : zero;

      const T img_fxcy = (fx >= 0 && cy <= data_height - 1)
          ? dx * (one - dy) * GET_DATA_POINT(fx, cy)
          : zero;

      const T img_cxfy = (cx <= data_width - 1 && fy >= 0)
          ? (one - dx) * dy * GET_DATA_POINT(cx, fy)
          : zero;

      output[out_index] = img_fxfy + img_cxcy + img_fxcy + img_cxfy;
    } else {
      output[out_index] = zero;
    }
  }
}

}  // namespace

namespace functor {

// GPU specialisation of the forward functor: launches Resampler2DKernel on the
// device's stream with a launch configuration derived from the number of
// output elements.
template <typename T>
struct Resampler2DFunctor<GPUDevice, T>{
  void operator ()(::tensorflow::OpKernelContext* ctx,
                   const GPUDevice& d,
                   const T* __restrict__ data,
                   const T* __restrict__ warp,
                   T* __restrict__ output,
                   const int batch_size,
                   const int data_height,
                   const int data_width,
                   const int data_channels,
                   const int num_sampling_points) {
    const int output_data_size = batch_size * num_sampling_points * data_channels;
    ::tensorflow::CudaLaunchConfig config =
        ::tensorflow::GetCudaLaunchConfig(output_data_size, d);
    Resampler2DKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            data, warp, output, batch_size, data_height, data_width,
            data_channels, num_sampling_points);
  }
};

// TODO(fviola): gcudacc fails at compile time with Eigen::half.
// template struct Resampler2DFunctor<GPUDevice, Eigen::half>;
template struct Resampler2DFunctor<GPUDevice, float>;
template struct Resampler2DFunctor<GPUDevice, double>;

}  // namespace functor

namespace {

// Atomically adds `v` to grad_data[batch_id, y, x, chan]. Uses the same
// in-scope names as GET_DATA_POINT, plus `grad_data`.
#define UPDATE_GRAD_DATA_POINT(x, y, v)                                      \
  atomicAdd(grad_data + (batch_id * data_batch_stride +                      \
                         data_channels * (y * data_width + x) +              \
                         chan),                                              \
            v)


// Gradient kernel: back-propagates `grad_output` through the bilinear
// interpolation performed by Resampler2DKernel.
//
// Each loop index handles one (batch_id, sample_id, chan) triple. Its
// contributions are *added* into `grad_data` and `grad_warp`, so both buffers
// must be zeroed before this kernel runs (the functor below does that).
template <typename T>
__global__ void ResamplerGrad2DKernel(const T* __restrict__ data,
                                      const T* __restrict__ warp,
                                      const T* __restrict__ grad_output,
                                      T* __restrict__ grad_data,
                                      T* __restrict__ grad_warp,
                                      const int batch_size,
                                      const int data_height,
                                      const int data_width,
                                      const int data_channels,
                                      const int num_sampling_points) {
  const int resampler_output_size = batch_size * num_sampling_points *
      data_channels;
  CUDA_1D_KERNEL_LOOP(index, resampler_output_size) {
    const int out_index = index;

    // Decompose the flat output index into (batch_id, sample_id, chan).
    // Use this formula
    //   index = batch_id * num_sampling_points * num_chans +
    //           sample_id * num_chans + chan_id,
    // with sample_id = [0, ... ,num_sampling_points)
    const int data_batch_stride = data_height * data_width * data_channels;
    const int warp_batch_stride = num_sampling_points * 2;
    const int output_batch_stride = num_sampling_points * data_channels;

    const int batch_id = index / output_batch_stride;
    const int index_in_batch = index % output_batch_stride;
    const int chan = index_in_batch % data_channels;
    const int sample_id = index_in_batch / data_channels;

    // Get coords of 2D point where data will be resampled
    const int warp_id_x = batch_id * warp_batch_stride + sample_id * 2;
    const int warp_id_y = warp_id_x + 1;
    const T x = warp[warp_id_x];
    const T y = warp[warp_id_y];
    const T zero = static_cast<T>(0.0);
    const T one = static_cast<T>(1.0);

    // Get grad output
    const T grad_output_value = grad_output[out_index];
    // The interpolation function whose gradient this kernel implements:
    // a) implicitly pads the input data with 0s (hence the unusual checks
    // with {x,y} > -1)
    // b) returns 0 when sampling outside the (padded) image.
    // The effect is that the sampled signal smoothly goes to 0 outside
    // the original input domain, rather than presenting a jump
    // discontinuity at the image boundaries.
    // Samples outside that domain contribute nothing, so the zero-initialised
    // gradients are simply left untouched for them.
    if (x > static_cast<T>(-1.0) &&
        y > static_cast<T>(-1.0) &&
        x < static_cast<T>(data_width) &&
        y < static_cast<T>(data_height)) {
      // Precompute floor (f) and ceil (c) values for x and y.
      const int fx = std::floor(static_cast<float>(x));
      const int fy = std::floor(static_cast<float>(y));
      const int cx = fx + 1;
      const int cy = fy + 1;
      const T dx = static_cast<T>(cx) - x;
      const T dy = static_cast<T>(cy) - y;

      // Unlike the forward kernel, these are the raw (unweighted) neighbour
      // values, with 0 standing in for neighbours in the implicit padding.
      const T img_fxfy = (fx >= 0 && fy >= 0)
          ? GET_DATA_POINT(fx, fy)
          : zero;

      const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
          ? GET_DATA_POINT(cx, cy)
          : zero;

      const T img_fxcy = (fx >= 0 && cy <= data_height - 1)
          ? GET_DATA_POINT(fx, cy)
          : zero;

      const T img_cxfy = (cx <= data_width - 1 && fy >= 0)
          ? GET_DATA_POINT(cx, fy)
          : zero;

      // Update partial gradients wrt relevant warp field entries.
      // warp_id_x / warp_id_y do not depend on `chan`, so every channel of
      // the same sample accumulates into the same two entries; hence the
      // atomic adds.
      atomicAdd(grad_warp + warp_id_x,
                grad_output_value * ((one - dy) * (img_cxcy - img_fxcy) +
                                     dy * (img_cxfy - img_fxfy)));
      atomicAdd(grad_warp + warp_id_y,
                grad_output_value * ((one - dx) * (img_cxcy - img_cxfy) +
                                     dx * (img_fxcy - img_fxfy)));

      // Update partial gradients wrt sampled data. Each in-bounds neighbour
      // receives grad_output scaled by the bilinear weight it had in the
      // forward pass; different samples may hit the same pixel, so these
      // updates are atomic as well.
      if (fx >= 0 && fy >= 0) {
        UPDATE_GRAD_DATA_POINT(fx, fy, grad_output_value * dx * dy);
      }
      if (cx <= data_width - 1 && cy <= data_height - 1) {
        UPDATE_GRAD_DATA_POINT(cx, cy,
                               grad_output_value * (one - dx) * (one - dy));
      }
      if (fx >= 0 && cy <= data_height - 1) {
        UPDATE_GRAD_DATA_POINT(fx, cy, grad_output_value * dx * (one - dy));
      }
      if (cx <= data_width - 1 && fy >= 0) {
        UPDATE_GRAD_DATA_POINT(cx, fy, grad_output_value * (one - dx) * dy);
      }
    }
  }
}

#undef GET_DATA_POINT
#undef UPDATE_GRAD_DATA_POINT

}  // namespace

namespace functor {

// GPU specialisation of the gradient functor: zeroes both gradient buffers,
// then launches ResamplerGrad2DKernel, all on the device's stream.
template <typename T>
struct ResamplerGrad2DFunctor<GPUDevice, T>{
  void operator ()(::tensorflow::OpKernelContext* ctx,
                   const GPUDevice& d,
                   const T* __restrict__ data,
                   const T* __restrict__ warp,
                   const T* __restrict__ grad_output,
                   T* __restrict__ grad_data,
                   T* __restrict__ grad_warp,
                   const int batch_size,
                   const int data_height,
                   const int data_width,
                   const int data_channels,
                   const int num_sampling_points) {
    // Set gradients to 0, because the kernel incrementally updates the
    // tensor entries by adding partial contributions.
    const int grad_warp_size = batch_size * num_sampling_points * 2;
    const int grad_data_size = batch_size * data_height * data_width *
        data_channels;

    ::tensorflow::CudaLaunchConfig config =
        ::tensorflow::GetCudaLaunchConfig(grad_warp_size, d);
    ::tensorflow::SetZero
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            grad_warp_size, grad_warp);

    config = ::tensorflow::GetCudaLaunchConfig(grad_data_size, d);
    ::tensorflow::SetZero
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            grad_data_size, grad_data);

    // The gradient kernel is sized by the number of *forward output* elements,
    // not by the size of either gradient buffer.
    const int resampler_output_size = batch_size * num_sampling_points *
        data_channels;
    config = ::tensorflow::GetCudaLaunchConfig(resampler_output_size, d);
    ResamplerGrad2DKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            data, warp, grad_output, grad_data, grad_warp, batch_size,
            data_height, data_width, data_channels, num_sampling_points);
  }
};

// Only the float gradient functor is instantiated; the GPU kernel registration
// disables half and double because atomicAdd is not supported for them.
template struct ResamplerGrad2DFunctor<GPUDevice, float>;

}  // namespace functor

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
diff --git a/tensorflow/contrib/resampler/ops/resampler_ops.cc b/tensorflow/contrib/resampler/ops/resampler_ops.cc new file mode 100644 index 0000000000..5ab212032e --- /dev/null +++ b/tensorflow/contrib/resampler/ops/resampler_ops.cc @@ -0,0 +1,59 @@
// Copyright 2017 The Sonnet Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeHandle;

// Forward op: resamples `data` at the coordinates given by `warp`.
// All inputs and the output share the element type T.
REGISTER_OP("Resampler")
    .Input("data: T")
    .Input("warp: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .SetShapeFn([](InferenceContext* c) {
      // Static shape inference only requires rank >= 1 for both inputs.
      ShapeHandle data;
      ShapeHandle warp;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data));
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &warp));

      // Drop warp's trailing dimension and append data's trailing (channel)
      // dimension in its place.
      ShapeHandle output;  // will be warp[:-1] + [data[-1]]
      TF_RETURN_IF_ERROR(c->Subshape(warp, 0, -1, &output));
      TF_RETURN_IF_ERROR(
          c->Concatenate(output, c->Vector(c->Dim(data, -1)), &output));

      c->set_output(0, output);
      return ::tensorflow::Status::OK();
    })
    .Doc(R"doc(Resampler op.)doc");

// Gradient op: takes the forward inputs plus the gradient of the forward
// output, and produces gradients for `data` and `warp`.
REGISTER_OP("ResamplerGrad")
    .Input("data: T")
    .Input("warp: T")
    .Input("grad_output: T")
    .Output("grad_data: T")
    .Output("grad_warp: T")
    .Attr("T: {half, float, double}")
    .SetShapeFn([](InferenceContext* c) {
      // Each gradient has exactly the shape of the input it corresponds to;
      // no consistency check against grad_output is done at this stage.
      c->set_output(0, c->input(0));
      c->set_output(1, c->input(1));
      return ::tensorflow::Status::OK();
    })
    .Doc(R"doc(Resampler Grad op.)doc");

}  // namespace tensorflow
diff --git a/tensorflow/contrib/resampler/python/__init__.py b/tensorflow/contrib/resampler/python/__init__.py new file mode 100644 index 0000000000..c5ca3a623f --- /dev/null +++ b/tensorflow/contrib/resampler/python/__init__.py @@ -0,0 +1,19 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ops module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/resampler/python/ops/resampler_ops.py b/tensorflow/contrib/resampler/python/ops/resampler_ops.py new file mode 100644 index 0000000000..355d15f0c7 --- /dev/null +++ b/tensorflow/contrib/resampler/python/ops/resampler_ops.py @@ -0,0 +1,69 @@ +# pylint: disable=g-bad-file-header +# Copyright 2017 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Tensorflow op performing differentiable resampling.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.resampler.ops import gen_resampler_ops +from tensorflow.contrib.util import loader +from tensorflow.python.framework import ops +from tensorflow.python.platform import resource_loader + +_resampler_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_resampler_ops.so")) + + +def resampler(data, warp, name="resampler"): + """Resamples input data at user defined coordinates. + + The resampler currently only supports bilinear interpolation of 2D data. + + Args: + data: Tensor of shape `[batch_size, data_height, data_width, + data_num_channels]` containing 2D data that will be resampled. + warp: Tensor of minimum rank 2 containing the coordinates at which + resampling will be performed. Since only bilinear interpolation is + currently supported, the last dimension of the `warp` tensor must be 2. + name: Optional name of the op. + + Returns: + Tensor of resampled values from `data`. The output tensor shape is + determined by the shape of the warp tensor. For example, if `data` is of + shape `[batch_size, data_height, data_width, data_num_channels]` and warp of + shape `[batch_size, dim_0, ... , dim_n, 2]` the output will be of shape + `[batch_size, dim_0, ... , dim_n, data_num_channels]`. + + Raises: + ImportError: if the wrapper generated during compilation is not present when + the function is called. 
+ """ + with ops.name_scope(name, "resampler", [data, warp]): + data_tensor = ops.convert_to_tensor(data, name="data") + warp_tensor = ops.convert_to_tensor(warp, name="warp") + return gen_resampler_ops.resampler(data_tensor, warp_tensor) + + +@ops.RegisterGradient("Resampler") +def _resampler_grad(op, grad_output): + data, warp = op.inputs + grad_output_tensor = ops.convert_to_tensor(grad_output, name="grad_output") + return gen_resampler_ops.resampler_grad(data, warp, grad_output_tensor) + + +ops.NotDifferentiable("ResamplerGrad") diff --git a/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py new file mode 100644 index 0000000000..6a4360150c --- /dev/null +++ b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py @@ -0,0 +1,270 @@ +# pylint: disable=g-bad-file-header +# Copyright 2017 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Tests for contrib.resampler.python.ops.resampler_ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.contrib import resampler +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +def _bilinearly_interpolate(data, x, y): + """Performs bilinenar interpolation of grid data at user defined coordinates. + + This interpolation function: + a) implicitly pads the input data with 0s. + b) returns 0 when sampling outside the (padded) image. + The effect is that the sampled signal smoothly goes to 0 outside the original + input domain, rather than producing a jump discontinuity at the image + boundaries. + + Args: + data: numpy array of shape `[data_height, data_width]` containing data + samples assumed to be defined at the corresponding pixel coordinates. + x: numpy array of shape `[warp_height, warp_width]` containing x coordinates + at which interpolation will be performed. + y: numpy array of shape `[warp_height, warp_width]` containing y coordinates + at which interpolation will be performed. + + Returns: + Numpy array of shape `[warp_height, warp_width]` containing interpolated + values. 
+ """ + shape = x.shape + x = np.asarray(x) + 1 + y = np.asarray(y) + 1 + data = np.lib.pad(data, 1, "constant", constant_values=0) + + x_0 = np.floor(x).astype(int) + x_1 = x_0 + 1 + y_0 = np.floor(y).astype(int) + y_1 = y_0 + 1 + + x_0 = np.clip(x_0, 0, data.shape[1] - 1) + x_1 = np.clip(x_1, 0, data.shape[1] - 1) + y_0 = np.clip(y_0, 0, data.shape[0] - 1) + y_1 = np.clip(y_1, 0, data.shape[0] - 1) + + i_a = data[y_0, x_0] + i_b = data[y_1, x_0] + i_c = data[y_0, x_1] + i_d = data[y_1, x_1] + + w_a = (x_1 - x) * (y_1 - y) + w_b = (x_1 - x) * (y - y_0) + w_c = (x - x_0) * (y_1 - y) + w_d = (x - x_0) * (y - y_0) + + samples = (w_a * i_a + w_b * i_b + w_c * i_c + w_d * i_d) + samples.reshape(shape) + + return samples + + +def _make_warp(batch_size, warp_height, warp_width, dtype): + """Creates batch of warping coordinates.""" + x, y = np.meshgrid(np.linspace(0, warp_width - 1, warp_width), + np.linspace(0, warp_height - 1, warp_height)) + warp = np.concatenate((x.reshape([warp_height, warp_width, 1]), + y.reshape([warp_height, warp_width, 1])), 2) + warp = np.tile(warp.reshape([1, warp_height, warp_width, 2]), + [batch_size, 1, 1, 1]) + warp += np.random.randn(*warp.shape) + return warp.astype(dtype) + + +class ResamplerTest(test.TestCase): + + def test_op_forward_pass_gpu_float32(self): + self._test_op_forward_pass(True, dtypes.float32, 1e-4) + + def test_op_forward_pass_gpu_float64(self): + self._test_op_forward_pass(True, dtypes.float64, 1e-5) + + def test_op_forward_pass_cpu_float16(self): + self._test_op_forward_pass(False, dtypes.float16, 1e-2) + + def test_op_forward_pass_cpu_float32(self): + self._test_op_forward_pass(False, dtypes.float32, 1e-4) + + def test_op_forward_pass_cpu_float64(self): + self._test_op_forward_pass(False, dtypes.float64, 1e-5) + + def test_op_backward_pass_gpu_float32(self): + self._test_op_backward_pass(True, dtypes.float32, 1e-3) + + def test_op_backward_pass_cpu_float16(self): + self._test_op_backward_pass(False, dtypes.float16, 
1e-3) + + def test_op_backward_pass_cpu_float32(self): + self._test_op_backward_pass(False, dtypes.float32, 1e-4) + + def test_op_backward_pass_cpu_float64(self): + self._test_op_backward_pass(False, dtypes.float64, 1e-6) + + def _test_op_forward_pass(self, on_gpu, dtype, tol): + np.random.seed(0) + data_width = 7 + data_height = 9 + data_channels = 5 + warp_width = 4 + warp_height = 8 + batch_size = 10 + + warp = _make_warp(batch_size, warp_height, warp_width, dtype.as_numpy_dtype) + data_shape = (batch_size, data_height, data_width, data_channels) + data = np.random.rand(*data_shape).astype(dtype.as_numpy_dtype) + + with self.test_session(use_gpu=on_gpu, force_gpu=False) as sess: + data_ph = array_ops.placeholder(dtype, shape=(None,) + data.shape[1:]) + warp_ph = array_ops.placeholder(dtype, shape=(None,) + warp.shape[1:]) + outputs = resampler.resampler(data=data_ph, warp=warp_ph) + self.assertEqual(outputs.get_shape().as_list(), + [None, warp_height, warp_width, data_channels]) + out = sess.run(outputs, feed_dict={data_ph: data, warp_ph: warp}) + + # Generate reference output via bilinear interpolation in numpy + reference_output = np.zeros_like(out) + for batch in xrange(batch_size): + for c in xrange(data_channels): + reference_output[batch, :, :, c] = _bilinearly_interpolate( + data[batch, :, :, c], + warp[batch, :, :, 0], + warp[batch, :, :, 1]) + + self.assertAllClose(out, reference_output, rtol=tol, atol=tol) + + def _test_op_backward_pass(self, on_gpu, dtype, tol): + np.random.seed(13) + data_width = 5 + data_height = 4 + data_channels = 3 + warp_width = 2 + warp_height = 6 + batch_size = 10 + + warp = _make_warp(batch_size, warp_height, warp_width, dtype.as_numpy_dtype) + data_shape = (batch_size, data_height, data_width, data_channels) + data = np.random.rand(*data_shape).astype(dtype.as_numpy_dtype) + + with self.test_session(use_gpu=on_gpu, force_gpu=False): + data_tensor = constant_op.constant(data) + warp_tensor = constant_op.constant(warp) + 
output_tensor = resampler.resampler(data=data_tensor, warp=warp_tensor) + + grads = test.compute_gradient([data_tensor, warp_tensor], [ + data_tensor.get_shape().as_list(), + warp_tensor.get_shape().as_list() + ], output_tensor, output_tensor.get_shape().as_list(), [data, warp]) + + if not on_gpu: + # On CPU we perform numerical differentiation at the best available + # precision, and compare against that. This is necessary for test to + # pass for float16. + data_tensor_64 = constant_op.constant(data, dtype=dtypes.float64) + warp_tensor_64 = constant_op.constant(warp, dtype=dtypes.float64) + output_tensor_64 = resampler.resampler(data=data_tensor_64, + warp=warp_tensor_64) + grads_64 = test.compute_gradient([data_tensor_64, warp_tensor_64], [ + data_tensor.get_shape().as_list(), + warp_tensor.get_shape().as_list() + ], output_tensor_64, output_tensor.get_shape().as_list(), [data, warp]) + + for g, g_64 in zip(grads, grads_64): + self.assertLess(np.fabs(g[0] - g_64[1]).max(), tol) + + else: + for g in grads: + self.assertLess(np.fabs(g[0] - g[1]).max(), tol) + + def test_op_errors(self): + data_width = 7 + data_height = 9 + data_depth = 3 + data_channels = 5 + warp_width = 4 + warp_height = 8 + batch_size = 10 + + # Input data shape is not defined over a 2D grid, i.e. its shape is not like + # (batch_size, data_height, data_width, data_channels). + with self.test_session() as sess: + data_shape = (batch_size, data_height, data_width, data_depth, + data_channels) + data = np.zeros(data_shape) + warp_shape = (batch_size, warp_height, warp_width, 2) + warp = np.zeros(warp_shape) + outputs = resampler.resampler(constant_op.constant(data), + constant_op.constant(warp)) + + with self.assertRaisesRegexp(errors_impl.UnimplementedError, + "Only bilinear interpolation is currently " + "supported."): + sess.run(outputs) + + # Warp tensor must be at least a matrix, with shape [batch_size, 2]. 
+ with self.test_session() as sess: + data_shape = (batch_size, data_height, data_width, data_channels) + data = np.zeros(data_shape) + warp_shape = (batch_size,) + warp = np.zeros(warp_shape) + outputs = resampler.resampler(constant_op.constant(data), + constant_op.constant(warp)) + + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "warp should be at least a matrix"): + sess.run(outputs) + + # The batch size of the data and warp tensors must be the same. + with self.test_session() as sess: + data_shape = (batch_size, data_height, data_width, data_channels) + data = np.zeros(data_shape) + warp_shape = (batch_size+1, warp_height, warp_width, 2) + warp = np.zeros(warp_shape) + outputs = resampler.resampler(constant_op.constant(data), + constant_op.constant(warp)) + + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "Batch size of data and warp tensor"): + sess.run(outputs) + + # The warp tensor must contain 2D coordinates, i.e. its shape last dimension + # must be 2. + with self.test_session() as sess: + data_shape = (batch_size, data_height, data_width, data_channels) + data = np.zeros(data_shape) + warp_shape = (batch_size, warp_height, warp_width, 3) + warp = np.zeros(warp_shape) + outputs = resampler.resampler(constant_op.constant(data), + constant_op.constant(warp)) + + with self.assertRaisesRegexp(errors_impl.UnimplementedError, + "Only bilinear interpolation is supported, " + "warping"): + sess.run(outputs) + + +if __name__ == "__main__": + test.main() |