author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-03-14 14:47:40 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-03-14 14:51:50 -0700
commit    5ac329bd86e400d47155e0c890669f4ee688771d (patch)
tree      4670d471082272ef0cf8474ae86a5d903f8084d5 /tensorflow/contrib/image
parent    ac8ce1fe760efff6585d790b784ec67255198879 (diff)
Adding non-linear image warping ops to tf.contrib.image
New ops are: tf.contrib.image.sparse_image_warp, tf.contrib.image.dense_image_warp, and tf.contrib.image.interpolate_spline.
PiperOrigin-RevId: 189089672
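As a rough usage sketch (not taken from this change's code), the three new ops can be exercised from a TF 1.x session as below; all shapes, control-point coordinates, and the session-based driver are illustrative assumptions:

    import numpy as np
    import tensorflow as tf

    # One 8x8 RGB image and a dense per-pixel flow field (example shapes).
    image = tf.constant(np.random.rand(1, 8, 8, 3), dtype=tf.float32)
    flow = tf.constant(np.random.rand(1, 8, 8, 2) * 0.5, dtype=tf.float32)
    warped_dense = tf.contrib.image.dense_image_warp(image, flow)

    # Sparse warp: displacements are given only at a few control points and are
    # spline-interpolated into a dense flow field internally.
    src = tf.constant([[[2., 2.], [5., 5.], [2., 5.]]], dtype=tf.float32)  # [batch, n, 2]
    dst = src + tf.constant([[[0., 1.], [1., 0.], [0., 0.]]], dtype=tf.float32)
    warped_sparse, dense_flow = tf.contrib.image.sparse_image_warp(image, src, dst)

    # Stand-alone polyharmonic spline interpolation of values at query points.
    train_points = tf.constant(np.random.rand(1, 10, 2), dtype=tf.float32)
    train_values = tf.constant(np.random.rand(1, 10, 1), dtype=tf.float32)
    query_points = tf.constant(np.random.rand(1, 4, 2), dtype=tf.float32)
    interpolated = tf.contrib.image.interpolate_spline(
        train_points, train_values, query_points, order=2)

    with tf.Session() as sess:
        print(sess.run([warped_dense, warped_sparse, interpolated]))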
Diffstat (limited to 'tensorflow/contrib/image')
-rwxr-xr-x  tensorflow/contrib/image/BUILD | 109
-rwxr-xr-x  tensorflow/contrib/image/__init__.py | 7
-rw-r--r--  tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py | 264
-rw-r--r--  tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py | 261
-rw-r--r--  tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py | 251
-rw-r--r--  tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png | bin 0 -> 14060 bytes
-rw-r--r--  tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png | bin 0 -> 18537 bytes
-rw-r--r--  tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png | bin 0 -> 19086 bytes
-rw-r--r--  tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png | bin 0 -> 18884 bytes
-rw-r--r--  tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png | bin 0 -> 18109 bytes
-rw-r--r--  tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png | bin 0 -> 19251 bytes
-rw-r--r--  tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png | bin 0 -> 19132 bytes
-rw-r--r--  tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png | bin 0 -> 17500 bytes
-rw-r--r--  tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png | bin 0 -> 18058 bytes
-rw-r--r--  tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png | bin 0 -> 19313 bytes
-rw-r--r--  tensorflow/contrib/image/python/ops/dense_image_warp.py | 196
-rw-r--r--  tensorflow/contrib/image/python/ops/interpolate_spline.py | 285
-rw-r--r--  tensorflow/contrib/image/python/ops/sparse_image_warp.py | 192
18 files changed, 1565 insertions, 0 deletions
diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD
index 3ff02e085e..760ed70fbb 100755
--- a/tensorflow/contrib/image/BUILD
+++ b/tensorflow/contrib/image/BUILD
@@ -78,7 +78,10 @@ tf_custom_op_py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":dense_image_warp_py",
":image_ops",
+ ":interpolate_spline_py",
+ ":sparse_image_warp_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:common_shapes",
@@ -194,6 +197,112 @@ cuda_py_test(
],
)
+py_library(
+ name = "dense_image_warp_py",
+ srcs = [
+ "python/ops/dense_image_warp.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "interpolate_spline_py",
+ srcs = [
+ "python/ops/interpolate_spline.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
+ name = "sparse_image_warp_py",
+ srcs = [
+ "python/ops/sparse_image_warp.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dense_image_warp_py",
+ ":interpolate_spline_py",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ ],
+)
+
+cuda_py_test(
+ name = "sparse_image_warp_test",
+ size = "medium",
+ srcs = ["python/kernel_tests/sparse_image_warp_test.py"],
+ additional_deps = [
+ ":sparse_image_warp_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:clip_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:image_ops",
+ "//tensorflow/python:variables",
+ "//tensorflow/core:protos_all_py",
+ ],
+ data = glob(["python/kernel_tests/test_data/*.png"]),
+)
+
+cuda_py_test(
+ name = "dense_image_warp_test",
+ size = "medium",
+ srcs = ["python/kernel_tests/dense_image_warp_test.py"],
+ additional_deps = [
+ ":dense_image_warp_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:clip_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:image_ops",
+ "//tensorflow/python:variables",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+cuda_py_test(
+ name = "interpolate_spline_test",
+ size = "medium",
+ srcs = ["python/kernel_tests/interpolate_spline_test.py"],
+ additional_deps = [
+ ":interpolate_spline_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:clip_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:image_ops",
+ "//tensorflow/python:variables",
+ "//tensorflow/core:protos_all_py",
+ "//third_party/py/scipy",
+ ],
+)
+
tf_py_test(
name = "segmentation_test",
size = "medium",
diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py
index cc8ed117ba..e982030bc8 100755
--- a/tensorflow/contrib/image/__init__.py
+++ b/tensorflow/contrib/image/__init__.py
@@ -30,6 +30,9 @@ projective transforms (including rotation) are supported.
@@transform
@@translate
@@translations_to_projective_transforms
+@@dense_image_warp
+@@interpolate_spline
+@@sparse_image_warp
## Image Segmentation `Ops`
@@ -47,6 +50,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.image.python.ops.dense_image_warp import dense_image_warp
+
from tensorflow.contrib.image.python.ops.distort_image_ops import adjust_hsv_in_yiq
from tensorflow.contrib.image.python.ops.distort_image_ops import random_hsv_in_yiq
@@ -57,7 +62,9 @@ from tensorflow.contrib.image.python.ops.image_ops import rotate
from tensorflow.contrib.image.python.ops.image_ops import transform
from tensorflow.contrib.image.python.ops.image_ops import translate
from tensorflow.contrib.image.python.ops.image_ops import translations_to_projective_transforms
+from tensorflow.contrib.image.python.ops.interpolate_spline import interpolate_spline
from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms import single_image_random_dot_stereograms
+from tensorflow.contrib.image.python.ops.sparse_image_warp import sparse_image_warp
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py
new file mode 100644
index 0000000000..24d99ccaa6
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py
@@ -0,0 +1,264 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for dense_image_warp."""
+
+import math
+import numpy as np
+
+from tensorflow.contrib.image.python.ops import dense_image_warp
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+from tensorflow.python.training import adam
+
+
+class DenseImageWarpTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ np.random.seed(0)
+
+ def test_interpolate_small_grid_ij(self):
+ grid = constant_op.constant(
+ [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1])
+ query_points = constant_op.constant(
+ [[0., 0.], [1., 0.], [2., 0.5], [1.5, 1.5]], shape=[1, 4, 2])
+ expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1])
+
+ interp = dense_image_warp._interpolate_bilinear(grid, query_points)
+
+ with self.test_session() as sess:
+ predicted = sess.run(interp)
+ self.assertAllClose(expected_results, predicted)
+
+ def test_interpolate_small_grid_xy(self):
+ grid = constant_op.constant(
+ [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1])
+ query_points = constant_op.constant(
+ [[0., 0.], [0., 1.], [0.5, 2.0], [1.5, 1.5]], shape=[1, 4, 2])
+ expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1])
+
+ interp = dense_image_warp._interpolate_bilinear(
+ grid, query_points, indexing='xy')
+
+ with self.test_session() as sess:
+ predicted = sess.run(interp)
+ self.assertAllClose(expected_results, predicted)
+
+ def test_interpolate_small_grid_batched(self):
+ grid = constant_op.constant(
+ [[[0., 1.], [3., 4.]], [[5., 6.], [7., 8.]]], shape=[2, 2, 2, 1])
+ query_points = constant_op.constant([[[0., 0.], [1., 0.], [0.5, 0.5]],
+ [[0.5, 0.], [1., 0.], [1., 1.]]])
+ expected_results = np.reshape(
+ np.array([[0., 3., 2.], [6., 7., 8.]]), [2, 3, 1])
+
+ interp = dense_image_warp._interpolate_bilinear(grid, query_points)
+
+ with self.test_session() as sess:
+ predicted = sess.run(interp)
+ self.assertAllClose(expected_results, predicted)
+
+ def get_image_and_flow_placeholders(self, shape, image_type, flow_type):
+ batch_size, height, width, numchannels = shape
+ image_shape = [batch_size, height, width, numchannels]
+ flow_shape = [batch_size, height, width, 2]
+
+ tf_type = {
+ 'float16': dtypes.half,
+ 'float32': dtypes.float32,
+ 'float64': dtypes.float64
+ }
+
+ image = array_ops.placeholder(dtype=tf_type[image_type], shape=image_shape)
+
+ flows = array_ops.placeholder(dtype=tf_type[flow_type], shape=flow_shape)
+ return image, flows
+
+ def get_random_image_and_flows(self, shape, image_type, flow_type):
+ batch_size, height, width, numchannels = shape
+ image_shape = [batch_size, height, width, numchannels]
+ image = np.random.normal(size=image_shape)
+ flow_shape = [batch_size, height, width, 2]
+ flows = np.random.normal(size=flow_shape) * 3
+ return image.astype(image_type), flows.astype(flow_type)
+
+ def assert_correct_interpolation_value(self,
+ image,
+ flows,
+ pred_interpolation,
+ batch_index,
+ y_index,
+ x_index,
+ low_precision=False):
+ """Assert that the tf interpolation matches hand-computed value."""
+
+ height = image.shape[1]
+ width = image.shape[2]
+ displacement = flows[batch_index, y_index, x_index, :]
+ float_y = y_index - displacement[0]
+ float_x = x_index - displacement[1]
+ floor_y = max(min(height - 2, math.floor(float_y)), 0)
+ floor_x = max(min(width - 2, math.floor(float_x)), 0)
+ ceil_y = floor_y + 1
+ ceil_x = floor_x + 1
+
+ alpha_y = min(max(0.0, float_y - floor_y), 1.0)
+ alpha_x = min(max(0.0, float_x - floor_x), 1.0)
+
+ floor_y = int(floor_y)
+ floor_x = int(floor_x)
+ ceil_y = int(ceil_y)
+ ceil_x = int(ceil_x)
+
+ top_left = image[batch_index, floor_y, floor_x, :]
+ top_right = image[batch_index, floor_y, ceil_x, :]
+ bottom_left = image[batch_index, ceil_y, floor_x, :]
+ bottom_right = image[batch_index, ceil_y, ceil_x, :]
+
+ interp_top = alpha_x * (top_right - top_left) + top_left
+ interp_bottom = alpha_x * (bottom_right - bottom_left) + bottom_left
+ interp = alpha_y * (interp_bottom - interp_top) + interp_top
+ atol = 1e-6
+ rtol = 1e-6
+ if low_precision:
+ atol = 1e-2
+ rtol = 1e-3
+ self.assertAllClose(
+ interp,
+ pred_interpolation[batch_index, y_index, x_index, :],
+ atol=atol,
+ rtol=rtol)
+
+ def check_zero_flow_correctness(self, shape, image_type, flow_type):
+ """Assert using zero flows doesn't change the input image."""
+
+ image, flows = self.get_image_and_flow_placeholders(shape, image_type,
+ flow_type)
+ interp = dense_image_warp.dense_image_warp(image, flows)
+
+ with self.test_session() as sess:
+ rand_image, rand_flows = self.get_random_image_and_flows(
+ shape, image_type, flow_type)
+ rand_flows *= 0
+
+ predicted_interpolation = sess.run(
+ interp, feed_dict={
+ image: rand_image,
+ flows: rand_flows
+ })
+ self.assertAllClose(rand_image, predicted_interpolation)
+
+ def test_zero_flows(self):
+ """Apply check_zero_flow_correctness() for a few sizes and types."""
+
+ shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]]
+ for shape in shapes_to_try:
+ self.check_zero_flow_correctness(
+ shape, image_type='float32', flow_type='float32')
+
+ def check_interpolation_correctness(self,
+ shape,
+ image_type,
+ flow_type,
+ num_probes=5):
+ """Interpolate, and then assert correctness for a few query locations."""
+
+ image, flows = self.get_image_and_flow_placeholders(shape, image_type,
+ flow_type)
+ interp = dense_image_warp.dense_image_warp(image, flows)
+ low_precision = image_type == 'float16' or flow_type == 'float16'
+ with self.test_session() as sess:
+ rand_image, rand_flows = self.get_random_image_and_flows(
+ shape, image_type, flow_type)
+
+ pred_interpolation = sess.run(
+ interp, feed_dict={
+ image: rand_image,
+ flows: rand_flows
+ })
+
+ for _ in range(num_probes):
+ batch_index = np.random.randint(0, shape[0])
+ y_index = np.random.randint(0, shape[1])
+ x_index = np.random.randint(0, shape[2])
+
+ self.assert_correct_interpolation_value(
+ rand_image,
+ rand_flows,
+ pred_interpolation,
+ batch_index,
+ y_index,
+ x_index,
+ low_precision=low_precision)
+
+ def test_interpolation(self):
+ """Apply check_interpolation_correctness() for a few sizes and types."""
+
+ shapes_to_try = [[3, 4, 5, 6], [1, 5, 5, 3], [1, 2, 2, 1]]
+ for im_type in ['float32', 'float64', 'float16']:
+ for flow_type in ['float32', 'float64', 'float16']:
+ for shape in shapes_to_try:
+ self.check_interpolation_correctness(shape, im_type, flow_type)
+
+ def test_gradients_exist(self):
+ """Check that backprop can run.
+
+ The correctness of the gradients is assumed, since the forward propagation
+ is tested to be correct and we only use built-in tf ops.
+ However, we perform a simple test to make sure that backprop can actually
+ run. We treat the flows as a tf.Variable and optimize them to minimize
+ the difference between the interpolated image and the input image.
+ """
+
+ batch_size, height, width, numchannels = [4, 5, 6, 7]
+ image_shape = [batch_size, height, width, numchannels]
+ image = random_ops.random_normal(image_shape)
+ flow_shape = [batch_size, height, width, 2]
+ init_flows = np.float32(np.random.normal(size=flow_shape) * 0.25)
+ flows = variables.Variable(init_flows)
+
+ interp = dense_image_warp.dense_image_warp(image, flows)
+ loss = math_ops.reduce_mean(math_ops.square(interp - image))
+
+ optimizer = adam.AdamOptimizer(1.0)
+ grad = gradients.gradients(loss, [flows])
+ opt_func = optimizer.apply_gradients(zip(grad, [flows]))
+ init_op = variables.global_variables_initializer()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(10):
+ sess.run(opt_func)
+
+ def test_size_exception(self):
+ """Make sure it throws an exception for images that are too small."""
+
+ shape = [1, 2, 1, 1]
+ msg = 'Should have raised an exception for invalid image size'
+ with self.assertRaises(ValueError, msg=msg):
+ self.check_interpolation_correctness(shape, 'float32', 'float32')
+
+
+if __name__ == '__main__':
+ googletest.main()
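As a side note (not part of the test file above), the expected values [0., 3., 6.5, 6.] in test_interpolate_small_grid_ij can be reproduced by hand with plain numpy, using the same clamped bilinear blend the op implements; the helper name below is just for illustration:

    import numpy as np

    grid = np.arange(9, dtype=np.float32).reshape(3, 3)   # [[0,1,2],[3,4,5],[6,7,8]]
    queries = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.5), (1.5, 1.5)]

    def bilinear_ij(g, y, x):
        # Clamp the floor so that floor + 1 stays inside the grid, as the op does.
        y0 = int(min(max(np.floor(y), 0), g.shape[0] - 2))
        x0 = int(min(max(np.floor(x), 0), g.shape[1] - 2))
        ay, ax = y - y0, x - x0
        top = g[y0, x0] + ax * (g[y0, x0 + 1] - g[y0, x0])
        bottom = g[y0 + 1, x0] + ax * (g[y0 + 1, x0 + 1] - g[y0 + 1, x0])
        return top + ay * (bottom - top)

    print([bilinear_ij(grid, y, x) for y, x in queries])   # -> [0.0, 3.0, 6.5, 6.0]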
diff --git a/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py
new file mode 100644
index 0000000000..1cba46e17e
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py
@@ -0,0 +1,261 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for interpolate_spline."""
+
+import numpy as np
+from scipy import interpolate as sc_interpolate
+
+from tensorflow.contrib.image.python.ops import interpolate_spline
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+from tensorflow.python.training import momentum
+
+
+class _InterpolationProblem(object):
+ """Abstract class for interpolation problem descriptions."""
+
+ def get_problem(self, optimizable=False, extrapolate=True, dtype='float32'):
+ """Make data for an interpolation problem where all x vectors are n-d.
+
+ Args:
+ optimizable: If True, then make train_points a tf.Variable.
+ extrapolate: If False, then clamp the query_points values to be within
+ the max and min of train_points.
+ dtype: The data type to use.
+
+ Returns:
+ query_points, query_values, train_points, train_values: training and
+ test tensors for interpolation problem
+ """
+
+ # The values generated here depend on a seed of 0.
+ np.random.seed(0)
+
+ batch_size = 1
+ num_training_points = 10
+ num_query_points = 4
+
+ init_points = np.random.uniform(
+ size=[batch_size, num_training_points, self.DATA_DIM])
+
+ init_points = init_points.astype(dtype)
+ train_points = (
+ variables.Variable(init_points)
+ if optimizable else constant_op.constant(init_points))
+ train_values = self.tf_function(train_points)
+
+ query_points_np = np.random.uniform(
+ size=[batch_size, num_query_points, self.DATA_DIM])
+ query_points_np = query_points_np.astype(dtype)
+ if not extrapolate:
+ query_points_np = np.clip(query_points_np, np.min(init_points),
+ np.max(init_points))
+
+ query_points = constant_op.constant(query_points_np)
+ query_values = self.np_function(query_points_np)
+
+ return query_points, query_values, train_points, train_values
+
+
+class _QuadraticPlusSinProblem1D(_InterpolationProblem):
+ """1D interpolation problem used for regression testing."""
+ DATA_DIM = 1
+ HARDCODED_QUERY_VALUES = {
+ (1.0, 0.0): [6.2647187603, -7.84362604077, -5.63690142322, 1.42928896387],
+ (1.0,
+ 0.01): [6.77688289946, -8.02163669853, -5.79491157027, 1.4063285693],
+ (2.0,
+ 0.0): [8.67110264937, -8.41281390883, -5.80190044693, 1.50155606059],
+ (2.0,
+ 0.01): [6.70797816797, -7.49709587663, -5.28965776238, 1.52284731741],
+ (3.0,
+ 0.0): [9.37691802935, -8.50390141515, -5.80786417426, 1.63467762122],
+ (3.0,
+ 0.01): [4.47106304758, -5.71266128361, -3.92529303296, 1.86755293857],
+ (4.0,
+ 0.0): [9.58172461111, -8.51432104771, -5.80967675388, 1.63361164256],
+ (4.0, 0.01): [
+ -3.87902711352, -0.0253462273846, 1.79857618022, -0.769339675725
+ ]
+ }
+
+ def np_function(self, x):
+ """Takes np array, evaluates the test function, and returns np array."""
+ return np.sum(
+ np.power((x - 0.5), 3) - 0.25 * x + 10 * np.sin(x * 10),
+ axis=2,
+ keepdims=True)
+
+ def tf_function(self, x):
+ """Takes tf tensor, evaluates the test function, and returns tf tensor."""
+ return math_ops.reduce_mean(
+ math_ops.pow((x - 0.5), 3) - 0.25 * x + 10 * math_ops.sin(x * 10),
+ 2,
+ keepdims=True)
+
+
+class _QuadraticPlusSinProblemND(_InterpolationProblem):
+ """3D interpolation problem used for regression testing."""
+
+ DATA_DIM = 3
+ HARDCODED_QUERY_VALUES = {
+ (1.0, 0.0): [1.06609663962, 1.28894849357, 1.10882405595, 1.63966936885],
+ (1.0, 0.01): [1.03123780748, 1.2952930985, 1.10366822954, 1.65265118569],
+ (2.0, 0.0): [0.627787735064, 1.43802857251, 1.00194632358, 1.91667538215],
+ (2.0, 0.01): [0.730159985046, 1.41702471595, 1.0065827217, 1.85758519312],
+ (3.0, 0.0): [0.350460417862, 1.67223539464, 1.00475331246, 2.31580322491],
+ (3.0,
+ 0.01): [0.624557250556, 1.63138876667, 0.976588193162, 2.12511237866],
+ (4.0,
+ 0.0): [0.898129669986, 1.24434133638, -0.938056116931, 1.59910338833],
+ (4.0,
+ 0.01): [0.0930360338179, -3.38791305538, -1.00969032567, 0.745535080382],
+ }
+
+ def np_function(self, x):
+ """Takes np array, evaluates the test function, and returns np array."""
+ return np.sum(
+ np.square(x - 0.5) + 0.25 * x + 1 * np.sin(x * 15),
+ axis=2,
+ keepdims=True)
+
+ def tf_function(self, x):
+ """Takes tf tensor, evaluates the test function, and returns tf tensor."""
+ return math_ops.reduce_sum(
+ math_ops.square(x - 0.5) + 0.25 * x + 1 * math_ops.sin(x * 15),
+ 2,
+ keepdims=True)
+
+
+class InterpolateSplineTest(test_util.TensorFlowTestCase):
+
+ def test_1d_linear_interpolation(self):
+ """For 1d linear interpolation, we can compare directly to scipy."""
+
+ tp = _QuadraticPlusSinProblem1D()
+ (query_points, _, train_points, train_values) = tp.get_problem(
+ extrapolate=False, dtype='float64')
+ interpolation_order = 1
+
+ with ops.name_scope('interpolator'):
+ interpolator = interpolate_spline.interpolate_spline(
+ train_points, train_values, query_points, interpolation_order)
+ with self.test_session() as sess:
+ fetches = [query_points, train_points, train_values, interpolator]
+ query_points_, train_points_, train_values_, interp_ = sess.run(fetches)
+
+ # Just look at the first element of the minibatch.
+ # Also, trim the final singleton dimension.
+ interp_ = interp_[0, :, 0]
+ query_points_ = query_points_[0, :, 0]
+ train_points_ = train_points_[0, :, 0]
+ train_values_ = train_values_[0, :, 0]
+
+ # Compute scipy interpolation.
+ scipy_interp_function = sc_interpolate.interp1d(
+ train_points_, train_values_, kind='linear')
+
+ scipy_interpolation = scipy_interp_function(query_points_)
+ scipy_interpolation_on_train = scipy_interp_function(train_points_)
+
+ # Even with float64 precision, the interpolants disagree with scipy a
+ # bit due to the fact that we add the EPSILON to prevent sqrt(0), etc.
+ tol = 1e-3
+
+ self.assertAllClose(
+ train_values_, scipy_interpolation_on_train, atol=tol, rtol=tol)
+ self.assertAllClose(interp_, scipy_interpolation, atol=tol, rtol=tol)
+
+ def test_1d_interpolation(self):
+ """Regression test for interpolation with 1-D points."""
+
+ tp = _QuadraticPlusSinProblem1D()
+ (query_points, _, train_points,
+ train_values) = tp.get_problem(dtype='float64')
+
+ for order in (1, 2, 3):
+ for reg_weight in (0, 0.01):
+ interpolator = interpolate_spline.interpolate_spline(
+ train_points, train_values, query_points, order, reg_weight)
+
+ target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)]
+ target_interpolation = np.array(target_interpolation)
+ with self.test_session() as sess:
+ interp_val = sess.run(interpolator)
+ self.assertAllClose(interp_val[0, :, 0], target_interpolation)
+
+ def test_nd_linear_interpolation(self):
+ """Regression test for interpolation with N-D points."""
+
+ tp = _QuadraticPlusSinProblemND()
+ (query_points, _, train_points,
+ train_values) = tp.get_problem(dtype='float64')
+
+ for order in (1, 2, 3):
+ for reg_weight in (0, 0.01):
+ interpolator = interpolate_spline.interpolate_spline(
+ train_points, train_values, query_points, order, reg_weight)
+
+ target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)]
+ target_interpolation = np.array(target_interpolation)
+ with self.test_session() as sess:
+ interp_val = sess.run(interpolator)
+ self.assertAllClose(interp_val[0, :, 0], target_interpolation)
+
+ def test_interpolation_gradient(self):
+ """Make sure that backprop can run. Correctness of gradients is assumed.
+
+ Here, we use a small 'training' set and a more densely-sampled
+ set of query points, for which we know the true value in advance. The goal
+ is to choose x locations for the training data such that interpolating using
+ this training data yields the best reconstruction for the function
+ values at the query points. The training data locations are optimized
+ iteratively using gradient descent.
+ """
+ tp = _QuadraticPlusSinProblemND()
+ (query_points, query_values, train_points,
+ train_values) = tp.get_problem(optimizable=True)
+
+ regularization = 0.001
+ for interpolation_order in (1, 2, 3, 4):
+ interpolator = interpolate_spline.interpolate_spline(
+ train_points, train_values, query_points, interpolation_order,
+ regularization)
+
+ loss = math_ops.reduce_mean(math_ops.square(query_values - interpolator))
+
+ optimizer = momentum.MomentumOptimizer(0.001, 0.9)
+ grad = gradients.gradients(loss, [train_points])
+ grad, _ = clip_ops.clip_by_global_norm(grad, 1.0)
+ opt_func = optimizer.apply_gradients(zip(grad, [train_points]))
+ init_op = variables.global_variables_initializer()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(100):
+ sess.run([loss, opt_func])
+
+
+if __name__ == '__main__':
+ googletest.main()
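The order-1 vs. scipy comparison in test_1d_linear_interpolation can also be reproduced outside the test harness; this is a rough standalone sketch in which the random data and the expected error magnitude are assumptions, not values from the test:

    import numpy as np
    import tensorflow as tf
    from scipy import interpolate as sc_interpolate

    np.random.seed(0)
    x_train = np.sort(np.random.uniform(size=[1, 10, 1]), axis=1)   # float64
    y_train = np.sin(10. * x_train)
    # Stay inside the training range: the order-1 spline and interp1d only
    # agree when no extrapolation is involved.
    x_query = np.clip(np.random.uniform(size=[1, 4, 1]),
                      x_train.min(), x_train.max())

    tf_interp = tf.contrib.image.interpolate_spline(
        tf.constant(x_train), tf.constant(y_train), tf.constant(x_query), order=1)
    with tf.Session() as sess:
        tf_vals = sess.run(tf_interp)[0, :, 0]

    scipy_fn = sc_interpolate.interp1d(x_train[0, :, 0], y_train[0, :, 0], kind='linear')
    # Small difference; the test above allows a 1e-3 tolerance due to EPSILON.
    print(np.max(np.abs(tf_vals - scipy_fn(x_query[0, :, 0]))))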
diff --git a/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py
new file mode 100644
index 0000000000..017969d230
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py
@@ -0,0 +1,251 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for sparse_image_warp."""
+
+import numpy as np
+
+from tensorflow.contrib.image.python.ops import sparse_image_warp
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import image_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+
+from tensorflow.python.training import momentum
+
+
+class SparseImageWarpTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ np.random.seed(0)
+
+ def testGetBoundaryLocations(self):
+ image_height = 11
+ image_width = 11
+ num_points_per_edge = 4
+ locs = sparse_image_warp._get_boundary_locations(image_height, image_width,
+ num_points_per_edge)
+ num_points = locs.shape[0]
+ self.assertEqual(num_points, 4 + 4 * num_points_per_edge)
+ locs = [(locs[i, 0], locs[i, 1]) for i in range(num_points)]
+ for i in (0, image_height - 1):
+ for j in (0, image_width - 1):
+ self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j))
+
+ for i in (2, 4, 6, 8):
+ for j in (0, image_width - 1):
+ self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j))
+
+ for i in (0, image_height - 1):
+ for j in (2, 4, 6, 8):
+ self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j))
+
+ def testGetGridLocations(self):
+ image_height = 5
+ image_width = 3
+ grid = sparse_image_warp._get_grid_locations(image_height, image_width)
+ for i in range(image_height):
+ for j in range(image_width):
+ self.assertEqual(grid[i, j, 0], i)
+ self.assertEqual(grid[i, j, 1], j)
+
+ def testZeroShift(self):
+ """Run assertZeroShift for various hyperparameters."""
+ for order in (1, 2):
+ for regularization in (0, 0.01):
+ for num_boundary_points in (0, 1):
+ self.assertZeroShift(order, regularization, num_boundary_points)
+
+ def assertZeroShift(self, order, regularization, num_boundary_points):
+ """Check that warping with zero displacements doesn't change the image."""
+ batch_size = 1
+ image_height = 4
+ image_width = 4
+ channels = 3
+
+ image = np.random.uniform(
+ size=[batch_size, image_height, image_width, channels])
+
+ input_image_op = constant_op.constant(np.float32(image))
+
+ control_point_locations = [[1., 1.], [2., 2.], [2., 1.]]
+ control_point_locations = constant_op.constant(
+ np.float32(np.expand_dims(control_point_locations, 0)))
+
+ control_point_displacements = np.zeros(
+ control_point_locations.shape.as_list())
+ control_point_displacements = constant_op.constant(
+ np.float32(control_point_displacements))
+
+ (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp(
+ input_image_op,
+ control_point_locations,
+ control_point_locations + control_point_displacements,
+ interpolation_order=order,
+ regularization_weight=regularization,
+ num_boundary_points=num_boundary_points)
+
+ with self.test_session() as sess:
+ warped_image, input_image, _ = sess.run(
+ [warped_image_op, input_image_op, flow_field])
+
+ self.assertAllClose(warped_image, input_image)
+
+ def testMoveSinglePixel(self):
+ """Run assertMoveSinglePixel for various hyperparameters and data types."""
+ for order in (1, 2):
+ for num_boundary_points in (1, 2):
+ for type_to_use in (dtypes.float32, dtypes.float64):
+ self.assertMoveSinglePixel(order, num_boundary_points, type_to_use)
+
+ def assertMoveSinglePixel(self, order, num_boundary_points, type_to_use):
+ """Move a single block in a small grid using warping."""
+ batch_size = 1
+ image_height = 7
+ image_width = 7
+ channels = 3
+
+ image = np.zeros([batch_size, image_height, image_width, channels])
+ image[:, 3, 3, :] = 1.0
+ input_image_op = constant_op.constant(image, dtype=type_to_use)
+
+ # Place a control point at the one white pixel.
+ control_point_locations = [[3., 3.]]
+ control_point_locations = constant_op.constant(
+ np.float32(np.expand_dims(control_point_locations, 0)),
+ dtype=type_to_use)
+ # Shift it one pixel to the right.
+ control_point_displacements = [[0., 1.0]]
+ control_point_displacements = constant_op.constant(
+ np.float32(np.expand_dims(control_point_displacements, 0)),
+ dtype=type_to_use)
+
+ (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp(
+ input_image_op,
+ control_point_locations,
+ control_point_locations + control_point_displacements,
+ interpolation_order=order,
+ num_boundary_points=num_boundary_points)
+
+ with self.test_session() as sess:
+ warped_image, input_image, flow = sess.run(
+ [warped_image_op, input_image_op, flow_field])
+ # Check that it moved the pixel correctly.
+ self.assertAllClose(
+ warped_image[0, 4, 5, :],
+ input_image[0, 4, 4, :],
+ atol=1e-5,
+ rtol=1e-5)
+
+ # Test that there is no flow at the corners.
+ for i in (0, image_height - 1):
+ for j in (0, image_width - 1):
+ self.assertAllClose(
+ flow[0, i, j, :], np.zeros([2]), atol=1e-5, rtol=1e-5)
+
+ def load_image(self, image_file, sess):
+ image_op = image_ops.decode_png(
+ io_ops.read_file(image_file), dtype=dtypes.uint8, channels=4)[:, :, 0:3]
+ return sess.run(image_op)
+
+ def testSmileyFace(self):
+ """Check warping accuracy by comparing to hardcoded warped images."""
+
+ test_data_dir = test.test_src_dir_path('contrib/image/python/'
+ 'kernel_tests/test_data/')
+ input_file = test_data_dir + 'Yellow_Smiley_Face.png'
+ with self.test_session() as sess:
+ input_image = self.load_image(input_file, sess)
+ control_points = np.asarray([[64, 59], [180 - 64, 59], [39, 111],
+ [180 - 39, 111], [90, 143], [58, 134],
+ [180 - 58, 134]]) # pyformat: disable
+ control_point_displacements = np.asarray(
+ [[-10.5, 10.5], [10.5, 10.5], [0, 0], [0, 0], [0, -10], [-20, 10.25],
+ [10, 10.75]])
+ control_points_op = constant_op.constant(
+ np.expand_dims(np.float32(control_points[:, [1, 0]]), 0))
+ control_point_displacements_op = constant_op.constant(
+ np.expand_dims(np.float32(control_point_displacements[:, [1, 0]]), 0))
+ float_image = np.expand_dims(np.float32(input_image) / 255, 0)
+ input_image_op = constant_op.constant(float_image)
+
+ for interpolation_order in (1, 2, 3):
+ for num_boundary_points in (0, 1, 4):
+ warp_op, _ = sparse_image_warp.sparse_image_warp(
+ input_image_op,
+ control_points_op,
+ control_points_op + control_point_displacements_op,
+ interpolation_order=interpolation_order,
+ num_boundary_points=num_boundary_points)
+ with self.test_session() as sess:
+ warped_image = sess.run(warp_op)
+ out_image = np.uint8(warped_image[0, :, :, :] * 255)
+ target_file = (
+ test_data_dir +
+ 'Yellow_Smiley_Face_Warp-interp' + '-{}-clamp-{}.png'.format(
+ interpolation_order, num_boundary_points))
+
+ target_image = self.load_image(target_file, sess)
+
+ # Check that the target_image and out_image difference is no
+ # bigger than 1 (on a scale of 0-255). Due to differences in
+ # floating point computation on different devices, the float
+ # output in warped_image may get rounded to a different int
+ # than that in the saved png file loaded into target_image.
+ self.assertAllClose(target_image, out_image, atol=1, rtol=1e-3)
+
+ def testThatBackpropRuns(self):
+ """Run optimization to ensure that gradients can be computed."""
+
+ batch_size = 1
+ image_height = 9
+ image_width = 12
+ image = variables.Variable(
+ np.float32(
+ np.random.uniform(size=[batch_size, image_height, image_width, 3])))
+ control_point_locations = [[3., 3.]]
+ control_point_locations = constant_op.constant(
+ np.float32(np.expand_dims(control_point_locations, 0)))
+ control_point_displacements = [[0.25, -0.5]]
+ control_point_displacements = constant_op.constant(
+ np.float32(np.expand_dims(control_point_displacements, 0)))
+ warped_image, _ = sparse_image_warp.sparse_image_warp(
+ image,
+ control_point_locations,
+ control_point_locations + control_point_displacements,
+ num_boundary_points=3)
+
+ loss = math_ops.reduce_mean(math_ops.abs(warped_image - image))
+ optimizer = momentum.MomentumOptimizer(0.001, 0.9)
+ grad = gradients.gradients(loss, [image])
+ grad, _ = clip_ops.clip_by_global_norm(grad, 1.0)
+ opt_func = optimizer.apply_gradients(zip(grad, [image]))
+ init_op = variables.global_variables_initializer()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run([loss, opt_func])
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png
new file mode 100644
index 0000000000..7e303881e2
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png
new file mode 100644
index 0000000000..7fd9e4e6d6
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png
new file mode 100644
index 0000000000..86d225e5d2
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png
new file mode 100644
index 0000000000..37e8ffae11
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png
new file mode 100644
index 0000000000..e49b581612
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png
new file mode 100644
index 0000000000..df3cf20043
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png
new file mode 100644
index 0000000000..e1799a87c8
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png
new file mode 100644
index 0000000000..2c346e0ce5
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png
new file mode 100644
index 0000000000..6f8b65451c
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png
new file mode 100644
index 0000000000..8e78146d95
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/ops/dense_image_warp.py b/tensorflow/contrib/image/python/ops/dense_image_warp.py
new file mode 100644
index 0000000000..9403003be1
--- /dev/null
+++ b/tensorflow/contrib/image/python/ops/dense_image_warp.py
@@ -0,0 +1,196 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Image warping using per-pixel flow vectors."""
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+
+def _interpolate_bilinear(grid,
+ query_points,
+ name='interpolate_bilinear',
+ indexing='ij'):
+ """Similar to Matlab's interp2 function.
+
+ Finds values for query points on a grid using bilinear interpolation.
+
+ Args:
+ grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`.
+ query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`.
+ name: a name for the operation (optional).
+ indexing: whether the query points are specified as row and column (ij),
+ or Cartesian coordinates (xy).
+
+ Returns:
+ values: a 3-D `Tensor` with shape `[batch, N, channels]`
+
+ Raises:
+ ValueError: if the indexing mode is invalid, or if the shape of the inputs
+ is invalid.
+ """
+ if indexing != 'ij' and indexing != 'xy':
+ raise ValueError('Indexing mode must be \'ij\' or \'xy\'')
+
+ with ops.name_scope(name):
+ shape = grid.get_shape().as_list()
+ if len(shape) != 4:
+ msg = 'Grid must be 4 dimensional. Received size: '
+ raise ValueError(msg + str(grid.get_shape()))
+
+ batch_size, height, width, channels = shape
+ query_type = query_points.dtype
+ grid_type = grid.dtype
+
+ if (len(query_points.get_shape()) != 3 or
+ query_points.get_shape()[2].value != 2):
+ msg = ('Query points must be 3 dimensional and size 2 in dim 2. Received '
+ 'size: ')
+ raise ValueError(msg + str(query_points.get_shape()))
+
+ _, num_queries, _ = query_points.get_shape().as_list()
+
+ if height < 2 or width < 2:
+ msg = 'Grid must be at least batch_size x 2 x 2 in size. Received size: '
+ raise ValueError(msg + str(grid.get_shape()))
+
+ alphas = []
+ floors = []
+ ceils = []
+
+ index_order = [0, 1] if indexing == 'ij' else [1, 0]
+ unstacked_query_points = array_ops.unstack(query_points, axis=2)
+
+ for dim in index_order:
+ with ops.name_scope('dim-' + str(dim)):
+ queries = unstacked_query_points[dim]
+
+ size_in_indexing_dimension = shape[dim + 1]
+
+ # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1
+ # is still a valid index into the grid.
+ max_floor = math_ops.cast(size_in_indexing_dimension - 2, query_type)
+ min_floor = constant_op.constant(0.0, dtype=query_type)
+ floor = math_ops.minimum(
+ math_ops.maximum(min_floor, math_ops.floor(queries)), max_floor)
+ int_floor = math_ops.cast(floor, dtypes.int32)
+ floors.append(int_floor)
+ ceil = int_floor + 1
+ ceils.append(ceil)
+
+ # alpha has the same type as the grid, as we will directly use alpha
+ # when taking linear combinations of pixel values from the image.
+ alpha = math_ops.cast(queries - floor, grid_type)
+ min_alpha = constant_op.constant(0.0, dtype=grid_type)
+ max_alpha = constant_op.constant(1.0, dtype=grid_type)
+ alpha = math_ops.minimum(math_ops.maximum(min_alpha, alpha), max_alpha)
+
+ # Expand alpha to [b, n, 1] so we can use broadcasting
+ # (since the alpha values don't depend on the channel).
+ alpha = array_ops.expand_dims(alpha, 2)
+ alphas.append(alpha)
+
+ if batch_size * height * width > np.iinfo(np.int32).max / 8:
+ error_msg = """The image size or batch size is sufficiently large
+ that the linearized addresses used by array_ops.gather
+ may exceed the int32 limit."""
+ raise ValueError(error_msg)
+
+ flattened_grid = array_ops.reshape(grid,
+ [batch_size * height * width, channels])
+ batch_offsets = array_ops.reshape(
+ math_ops.range(batch_size) * height * width, [batch_size, 1])
+
+ # This wraps array_ops.gather. We reshape the image data such that the
+ # batch, y, and x coordinates are pulled into the first dimension.
+ # Then we gather. Finally, we reshape the output back. It's possible this
+ # code would be made simpler by using array_ops.gather_nd.
+ def gather(y_coords, x_coords, name):
+ with ops.name_scope('gather-' + name):
+ linear_coordinates = batch_offsets + y_coords * width + x_coords
+ gathered_values = array_ops.gather(flattened_grid, linear_coordinates)
+ return array_ops.reshape(gathered_values,
+ [batch_size, num_queries, channels])
+
+ # grab the pixel values in the 4 corners around each query point
+ top_left = gather(floors[0], floors[1], 'top_left')
+ top_right = gather(floors[0], ceils[1], 'top_right')
+ bottom_left = gather(ceils[0], floors[1], 'bottom_left')
+ bottom_right = gather(ceils[0], ceils[1], 'bottom_right')
+
+ # now, do the actual interpolation
+ with ops.name_scope('interpolate'):
+ interp_top = alphas[1] * (top_right - top_left) + top_left
+ interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
+ interp = alphas[0] * (interp_bottom - interp_top) + interp_top
+
+ return interp
+
+
+def dense_image_warp(image, flow, name='dense_image_warp'):
+ """Image warping using per-pixel flow vectors.
+
+ Apply a non-linear warp to the image, where the warp is specified by a dense
+ flow field of offset vectors that define the correspondences of pixel values
+ in the output image back to locations in the source image. Specifically, the
+ pixel value at output[b, j, i, c] is
+ images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c].
+
+ The locations specified by this formula do not necessarily map to an int
+ index. Therefore, the pixel value is obtained by bilinear
+ interpolation of the 4 nearest pixels around
+ (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside
+ of the image, we use the nearest pixel values at the image boundary.
+
+
+ Args:
+ image: 4-D float `Tensor` with shape `[batch, height, width, channels]`.
+ flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`.
+ name: A name for the operation (optional).
+
+ Note that image and flow can be of type tf.half, tf.float32, or tf.float64,
+ and do not necessarily have to be the same type.
+
+ Returns:
+ A 4-D float `Tensor` with shape `[batch, height, width, channels]`
+ and same type as input image.
+
+ Raises:
+ ValueError: if height < 2 or width < 2 or the inputs have the wrong number
+ of dimensions.
+ """
+ with ops.name_scope(name):
+ batch_size, height, width, channels = image.get_shape().as_list()
+ # The flow is defined on the image grid. Turn the flow into a list of query
+ # points in the grid space.
+ grid_x, grid_y = array_ops.meshgrid(
+ math_ops.range(width), math_ops.range(height))
+ stacked_grid = math_ops.cast(
+ array_ops.stack([grid_y, grid_x], axis=2), flow.dtype)
+ batched_grid = array_ops.expand_dims(stacked_grid, axis=0)
+ query_points_on_grid = batched_grid - flow
+ query_points_flattened = array_ops.reshape(query_points_on_grid,
+ [batch_size, height * width, 2])
+ # Compute values at the query points, then reshape the result back to the
+ # image grid.
+ interpolated = _interpolate_bilinear(image, query_points_flattened)
+ interpolated = array_ops.reshape(interpolated,
+ [batch_size, height, width, channels])
+ return interpolated
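As a concrete reading of the formula output[b, j, i, c] = image[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c], a constant integer flow simply shifts the image content. The numpy sketch below mirrors that lookup (it is not part of the op; bilinear blending is unnecessary here because the source coordinates are integers):

    import numpy as np

    image = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
    # A constant flow of (0, 1) means every output pixel reads one column to its left.
    flow = np.zeros((1, 4, 4, 2), dtype=np.float32)
    flow[..., 1] = 1.0

    output = np.zeros_like(image)
    for j in range(4):
        for i in range(4):
            src_j = int(j - flow[0, j, i, 0])
            src_i = int(i - flow[0, j, i, 1])
            # Clamp to the image boundary, matching the op's edge behaviour.
            src_j = min(max(src_j, 0), 3)
            src_i = min(max(src_i, 0), 3)
            output[0, j, i, 0] = image[0, src_j, src_i, 0]
    # Row 0 of the output is [0, 0, 1, 2]: the content moved one pixel to the right.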
diff --git a/tensorflow/contrib/image/python/ops/interpolate_spline.py b/tensorflow/contrib/image/python/ops/interpolate_spline.py
new file mode 100644
index 0000000000..ad17921991
--- /dev/null
+++ b/tensorflow/contrib/image/python/ops/interpolate_spline.py
@@ -0,0 +1,285 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Polyharmonic spline interpolation."""
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+
+EPSILON = 0.0000000001
+
+
+def _cross_squared_distance_matrix(x, y):
+ """Pairwise squared distance between two (batch) matrices' rows (2nd dim).
+
+ Computes the pairwise distances between rows of x and rows of y
+ Args:
+ x: [batch_size, n, d] float `Tensor`
+ y: [batch_size, m, d] float `Tensor`
+
+ Returns:
+ squared_dists: [batch_size, n, m] float `Tensor`, where
+ squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2
+ """
+ x_norm_squared = math_ops.reduce_sum(math_ops.square(x), 2)
+ y_norm_squared = math_ops.reduce_sum(math_ops.square(y), 2)
+
+ # Expand so that we can broadcast.
+ x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2)
+ y_norm_squared_tile = array_ops.expand_dims(y_norm_squared, 1)
+
+ x_y_transpose = math_ops.matmul(x, y, adjoint_b=True)
+
+ # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj
+ squared_dists = x_norm_squared_tile - 2 * x_y_transpose + y_norm_squared_tile
+
+ return squared_dists
+
+
+def _pairwise_squared_distance_matrix(x):
+ """Pairwise squared distance among a (batch) matrix's rows (2nd dim).
+
+ This saves a bit of computation vs. using _cross_squared_distance_matrix(x,x)
+
+ Args:
+ x: `[batch_size, n, d]` float `Tensor`
+
+ Returns:
+ squared_dists: `[batch_size, n, n]` float `Tensor`, where
+ squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2
+ """
+
+ x_x_transpose = math_ops.matmul(x, x, adjoint_b=True)
+ x_norm_squared = array_ops.matrix_diag_part(x_x_transpose)
+ x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2)
+
+ # squared_dists[b,i,j] = ||x_bi - x_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj
+ squared_dists = x_norm_squared_tile - 2 * x_x_transpose + array_ops.transpose(
+ x_norm_squared_tile, [0, 2, 1])
+
+ return squared_dists
+
+
+def _solve_interpolation(train_points, train_values, order,
+ regularization_weight):
+ """Solve for interpolation coefficients.
+
+ Computes the coefficients of the polyharmonic interpolant for the 'training'
+ data defined by (train_points, train_values) using the kernel phi.
+
+ Args:
+ train_points: `[b, n, d]` interpolation centers
+ train_values: `[b, n, k]` function values
+ order: order of the interpolation
+ regularization_weight: weight to place on smoothness regularization term
+
+ Returns:
+ w: `[b, n, k]` weights on each interpolation center
+ v: `[b, d, k]` weights on each input dimension
+ """
+
+ b, n, d = train_points.get_shape().as_list()
+ _, _, k = train_values.get_shape().as_list()
+
+ # First, rename variables so that the notation (c, f, w, v, A, B, etc.)
+ # follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
+ # To account for python style guidelines we use
+ # matrix_a for A and matrix_b for B.
+
+ c = train_points
+ f = train_values
+
+ # Next, construct the linear system.
+ with ops.name_scope('construct_linear_system'):
+
+ matrix_a = _phi(_pairwise_squared_distance_matrix(c), order) # [b, n, n]
+ if regularization_weight > 0:
+ batch_identity_matrix = np.expand_dims(np.eye(n), 0)
+ batch_identity_matrix = constant_op.constant(
+ batch_identity_matrix, dtype=train_points.dtype)
+
+ matrix_a += regularization_weight * batch_identity_matrix
+
+ # Append ones to the feature values for the bias term in the linear model.
+ ones = array_ops.ones([b, n, 1], train_points.dtype)
+ matrix_b = array_ops.concat([c, ones], 2) # [b, n, d + 1]
+
+ # [b, n + d + 1, n]
+ left_block = array_ops.concat(
+ [matrix_a, array_ops.transpose(matrix_b, [0, 2, 1])], 1)
+
+ num_b_cols = matrix_b.get_shape()[2] # d + 1
+ lhs_zeros = array_ops.zeros([b, num_b_cols, num_b_cols], train_points.dtype)
+ right_block = array_ops.concat([matrix_b, lhs_zeros],
+ 1) # [b, n + d + 1, d + 1]
+ lhs = array_ops.concat([left_block, right_block],
+ 2) # [b, n + d + 1, n + d + 1]
+
+ rhs_zeros = array_ops.zeros([b, d + 1, k], train_points.dtype)
+ rhs = array_ops.concat([f, rhs_zeros], 1) # [b, n + d + 1, k]
+
+ # Then, solve the linear system and unpack the results.
+ with ops.name_scope('solve_linear_system'):
+ w_v = linalg_ops.matrix_solve(lhs, rhs)
+ w = w_v[:, :n, :]
+ v = w_v[:, n:, :]
+
+ return w, v
+
+
+def _apply_interpolation(query_points, train_points, w, v, order):
+ """Apply polyharmonic interpolation model to data.
+
+ Given coefficients w and v for the interpolation model, we evaluate
+ interpolated function values at query_points.
+
+ Args:
+ query_points: `[b, m, d]` x values to evaluate the interpolation at
+ train_points: `[b, n, d]` x values that act as the interpolation centers
+ (the c variables in the Wikipedia article)
+ w: `[b, n, k]` weights on each interpolation center
+ v: `[b, d, k]` weights on each input dimension
+ order: order of the interpolation
+
+ Returns:
+ Polyharmonic interpolation evaluated at points defined in query_points.
+ """
+
+ batch_size = train_points.get_shape()[0].value
+ num_query_points = query_points.get_shape()[1].value
+
+ # First, compute the contribution from the rbf term.
+ pairwise_dists = _cross_squared_distance_matrix(query_points, train_points)
+ phi_pairwise_dists = _phi(pairwise_dists, order)
+
+ rbf_term = math_ops.matmul(phi_pairwise_dists, w)
+
+ # Then, compute the contribution from the linear term.
+ # Pad query_points with ones, for the bias term in the linear model.
+ query_points_pad = array_ops.concat([
+ query_points,
+ array_ops.ones([batch_size, num_query_points, 1], train_points.dtype)
+ ], 2)
+ linear_term = math_ops.matmul(query_points_pad, v)
+
+ return rbf_term + linear_term
+
+
+def _phi(r, order):
+ """Coordinate-wise nonlinearity used to define the order of the interpolation.
+
+ See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.
+
+ Args:
+ r: input op
+ order: interpolation order
+
+ Returns:
+ phi_k evaluated coordinate-wise on r, for k = order
+ """
+
+ # Using EPSILON prevents log(0), sqrt(0), etc.
+ # sqrt(0) is well-defined, but its gradient is not
+ with ops.name_scope('phi'):
+ if order == 1:
+ r = math_ops.maximum(r, EPSILON)
+ r = math_ops.sqrt(r)
+ return r
+ elif order == 2:
+ return 0.5 * r * math_ops.log(math_ops.maximum(r, EPSILON))
+ elif order == 4:
+ return 0.5 * math_ops.square(r) * math_ops.log(
+ math_ops.maximum(r, EPSILON))
+ elif order % 2 == 0:
+ r = math_ops.maximum(r, EPSILON)
+ return 0.5 * math_ops.pow(r, 0.5 * order) * math_ops.log(r)
+ else:
+ r = math_ops.maximum(r, EPSILON)
+ return math_ops.pow(r, 0.5 * order)
+
+
+def interpolate_spline(train_points,
+ train_values,
+ query_points,
+ order,
+ regularization_weight=0.0,
+ name='interpolate_spline'):
+ r"""Interpolate signal using polyharmonic interpolation.
+
+ The interpolant has the form
+ $$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$
+
+ This is a sum of two terms: (1) a weighted sum of radial basis function (RBF)
+ terms, with the centers \\(c_1, ... c_n\\), and (2) a linear term with a bias.
+ The \\(c_i\\) vectors are 'training' points. In the code, b is absorbed into v
+ by appending 1 as a final dimension to x. The coefficients w and v are
+ estimated such that the interpolant exactly fits the value of the function at
+ the \\(c_i\\) points, the vector w is orthogonal to each \\(c_i\\), and the
+ vector w sums to 0. With these constraints, the coefficients can be obtained
+ by solving a linear system.
+
+ \\(\phi\\) is an RBF, parametrized by an interpolation
+ order. Using order=2 produces the well-known thin-plate spline.
+
+ We also provide the option to perform regularized interpolation. Here, the
+ interpolant is selected to trade off between the squared loss on the training
+ data and a certain measure of its curvature
+ ([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)).
+ Using a regularization weight greater than zero has the effect that the
+ interpolant will no longer exactly fit the training data. However, it may be
+ less vulnerable to overfitting, particularly for high-order interpolation.
+
+ Note the interpolation procedure is differentiable with respect to all inputs
+ besides the order parameter.
+
+ Args:
+ train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional
+ locations. These do not need to be regularly-spaced.
+    train_values: `[batch_size, n, k]` float `Tensor` of n k-dimensional values
+ evaluated at train_points.
+ query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations
+ where we will output the interpolant's values.
+ order: order of the interpolation. Common values are 1 for
+ \\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\) (thin-plate spline),
+ or 3 for \\(\phi(r) = r^3\\).
+ regularization_weight: weight placed on the regularization term.
+ This will depend substantially on the problem, and it should always be
+ tuned. For many problems, it is reasonable to use no regularization.
+ If using a non-zero value, we recommend a small value like 0.001.
+ name: name prefix for ops created by this function
+
+ Returns:
+ `[b, m, k]` float `Tensor` of query values. We use train_points and
+ train_values to perform polyharmonic interpolation. The query values are
+ the values of the interpolant evaluated at the locations specified in
+ query_points.
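+
+  A minimal usage sketch (illustrative only; assumes a TF 1.x graph/session
+  runtime where this op is exported as tf.contrib.image.interpolate_spline):
+
+    import tensorflow as tf
+
+    # Four scattered 2-D points with scalar values, queried at two locations.
+    train_points = tf.constant([[[0., 0.], [1., 0.], [0., 1.], [1., 1.]]])
+    train_values = tf.constant([[[0.], [1.], [2.], [3.]]])
+    query_points = tf.constant([[[0.5, 0.5], [0.25, 0.75]]])
+
+    query_values = tf.contrib.image.interpolate_spline(
+        train_points, train_values, query_points, order=2)
+
+    with tf.Session() as sess:
+      print(sess.run(query_values))  # a [1, 2, 1] array of query values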
+ """
+ with ops.name_scope(name):
+
+ # First, fit the spline to the observed data.
+ with ops.name_scope('solve'):
+ w, v = _solve_interpolation(train_points, train_values, order,
+ regularization_weight)
+
+ # Then, evaluate the spline at the query locations.
+ with ops.name_scope('predict'):
+ query_values = _apply_interpolation(query_points, train_points, w, v,
+ order)
+
+ return query_values
diff --git a/tensorflow/contrib/image/python/ops/sparse_image_warp.py b/tensorflow/contrib/image/python/ops/sparse_image_warp.py
new file mode 100644
index 0000000000..9f50503d8f
--- /dev/null
+++ b/tensorflow/contrib/image/python/ops/sparse_image_warp.py
@@ -0,0 +1,192 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Image warping using sparse flow defined at control points."""
+
+import numpy as np
+
+from tensorflow.contrib.image.python.ops import dense_image_warp
+from tensorflow.contrib.image.python.ops import interpolate_spline
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+
+
+def _get_grid_locations(image_height, image_width):
+ """Wrapper for np.meshgrid."""
+
+ y_range = np.linspace(0, image_height - 1, image_height)
+ x_range = np.linspace(0, image_width - 1, image_width)
+ y_grid, x_grid = np.meshgrid(y_range, x_range, indexing='ij')
+ return np.stack((y_grid, x_grid), -1)
+
+
+def _expand_to_minibatch(np_array, batch_size):
+ """Tile arbitrarily-sized np_array to include new batch dimension."""
+ tiles = [batch_size] + [1] * np_array.ndim
+ return np.tile(np.expand_dims(np_array, 0), tiles)
+
+
+def _get_boundary_locations(image_height, image_width, num_points_per_edge):
+ """Compute evenly-spaced indices along edge of image."""
+ y_range = np.linspace(0, image_height - 1, num_points_per_edge + 2)
+ x_range = np.linspace(0, image_width - 1, num_points_per_edge + 2)
+ ys, xs = np.meshgrid(y_range, x_range, indexing='ij')
+ is_boundary = np.logical_or(
+ np.logical_or(xs == 0, xs == image_width - 1),
+ np.logical_or(ys == 0, ys == image_height - 1))
+ return np.stack([ys[is_boundary], xs[is_boundary]], axis=-1)
+
+
+def _add_zero_flow_controls_at_boundary(control_point_locations,
+ control_point_flows, image_height,
+ image_width, boundary_points_per_edge):
+ """Add control points for zero-flow boundary conditions.
+
+ Augment the set of control points with extra points on the
+ boundary of the image that have zero flow.
+
+ Args:
+ control_point_locations: input control points
+ control_point_flows: their flows
+ image_height: image height
+ image_width: image width
+ boundary_points_per_edge: number of points to add in the middle of each
+ edge (not including the corners).
+ The total number of points added is
+ 4 + 4*(boundary_points_per_edge).
+
+ Returns:
+ merged_control_point_locations: augmented set of control point locations
+ merged_control_point_flows: augmented set of control point flows
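+
+  For example (illustrative): with boundary_points_per_edge=1, the four image
+  corners and the four edge midpoints are appended, i.e. 4 + 4*1 = 8 extra
+  control points, each with zero flow.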
+ """
+
+ batch_size = control_point_locations.get_shape()[0].value
+
+ boundary_point_locations = _get_boundary_locations(image_height, image_width,
+ boundary_points_per_edge)
+
+ boundary_point_flows = np.zeros([boundary_point_locations.shape[0], 2])
+
+ type_to_use = control_point_locations.dtype
+ boundary_point_locations = constant_op.constant(
+ _expand_to_minibatch(boundary_point_locations, batch_size),
+ dtype=type_to_use)
+
+ boundary_point_flows = constant_op.constant(
+ _expand_to_minibatch(boundary_point_flows, batch_size), dtype=type_to_use)
+
+ merged_control_point_locations = array_ops.concat(
+ [control_point_locations, boundary_point_locations], 1)
+
+ merged_control_point_flows = array_ops.concat(
+ [control_point_flows, boundary_point_flows], 1)
+
+ return merged_control_point_locations, merged_control_point_flows
+
+
+def sparse_image_warp(image,
+ source_control_point_locations,
+ dest_control_point_locations,
+ interpolation_order=2,
+ regularization_weight=0.0,
+ num_boundary_points=0,
+ name='sparse_image_warp'):
+ """Image warping using correspondences between sparse control points.
+
+ Apply a non-linear warp to the image, where the warp is specified by
+ the source and destination locations of a (potentially small) number of
+ control points. First, we use a polyharmonic spline
+ (@{tf.contrib.image.interpolate_spline}) to interpolate the displacements
+ between the corresponding control points to a dense flow field.
+ Then, we warp the image using this dense flow field
+ (@{tf.contrib.image.dense_image_warp}).
+
+ Let t index our control points. For regularization_weight=0, we have:
+ warped_image[b, dest_control_point_locations[b, t, 0],
+ dest_control_point_locations[b, t, 1], :] =
+ image[b, source_control_point_locations[b, t, 0],
+ source_control_point_locations[b, t, 1], :].
+
+ For regularization_weight > 0, this condition is met approximately, since
+  regularized interpolation trades off smoothness of the interpolant against
+  exactly reproducing the flows at the control points.
+ See @{tf.contrib.image.interpolate_spline} for further documentation of the
+ interpolation_order and regularization_weight arguments.
+
+
+ Args:
+ image: `[batch, height, width, channels]` float `Tensor`
+ source_control_point_locations: `[batch, num_control_points, 2]` float
+ `Tensor`
+ dest_control_point_locations: `[batch, num_control_points, 2]` float
+ `Tensor`
+ interpolation_order: polynomial order used by the spline interpolation
+ regularization_weight: weight on smoothness regularizer in interpolation
+ num_boundary_points: How many zero-flow boundary points to include at
+      each image edge. Usage:
+ num_boundary_points=0: don't add zero-flow points
+ num_boundary_points=1: 4 corners of the image
+ num_boundary_points=2: 4 corners and one in the middle of each edge
+ (8 points total)
+ num_boundary_points=n: 4 corners and n-1 along each edge
+ name: A name for the operation (optional).
+
+  Note that image and the control point locations can be of type tf.half,
+  tf.float32, or tf.float64, and do not necessarily have to be the same type.
+
+ Returns:
+ warped_image: `[batch, height, width, channels]` float `Tensor` with same
+ type as input image.
+ flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense
+ flow field produced by the interpolation.
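+
+  A minimal usage sketch (illustrative only; assumes a TF 1.x graph/session
+  runtime where this op is exported as tf.contrib.image.sparse_image_warp):
+
+    import tensorflow as tf
+
+    image = tf.zeros([1, 64, 64, 3], dtype=tf.float32)
+    # Drag the content at control point (20, 20) to (25, 25), pinning the
+    # boundary with num_boundary_points=2 (4 corners plus one midpoint per
+    # edge, all with zero flow).
+    source = tf.constant([[[20., 20.]]])  # [batch, num_control_points, 2]
+    dest = tf.constant([[[25., 25.]]])
+
+    warped, flow = tf.contrib.image.sparse_image_warp(
+        image, source, dest, num_boundary_points=2)
+
+    with tf.Session() as sess:
+      warped_np, flow_np = sess.run([warped, flow])
+      # warped_np has shape [1, 64, 64, 3]; flow_np has shape [1, 64, 64, 2].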
+ """
+
+ control_point_flows = (
+ dest_control_point_locations - source_control_point_locations)
+
+ clamp_boundaries = num_boundary_points > 0
+ boundary_points_per_edge = num_boundary_points - 1
+
+ with ops.name_scope(name):
+
+ batch_size, image_height, image_width, _ = image.get_shape().as_list()
+
+ # This generates the dense locations where the interpolant
+ # will be evaluated.
+ grid_locations = _get_grid_locations(image_height, image_width)
+
+ flattened_grid_locations = np.reshape(grid_locations,
+ [image_height * image_width, 2])
+
+ flattened_grid_locations = constant_op.constant(
+ _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype)
+
+ if clamp_boundaries:
+ (dest_control_point_locations,
+ control_point_flows) = _add_zero_flow_controls_at_boundary(
+ dest_control_point_locations, control_point_flows, image_height,
+ image_width, boundary_points_per_edge)
+
+ flattened_flows = interpolate_spline.interpolate_spline(
+ dest_control_point_locations, control_point_flows,
+ flattened_grid_locations, interpolation_order, regularization_weight)
+
+ dense_flows = array_ops.reshape(flattened_flows,
+ [batch_size, image_height, image_width, 2])
+
+ warped_image = dense_image_warp.dense_image_warp(image, dense_flows)
+
+ return warped_image, dense_flows