aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image/python/kernel_tests
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/image/python/kernel_tests')
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py267
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py264
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py254
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.pngbin0 -> 14060 bytes
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.pngbin0 -> 18537 bytes
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.pngbin0 -> 19086 bytes
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.pngbin0 -> 18884 bytes
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.pngbin0 -> 18109 bytes
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.pngbin0 -> 19251 bytes
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.pngbin0 -> 19132 bytes
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.pngbin0 -> 17500 bytes
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.pngbin0 -> 18058 bytes
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.pngbin0 -> 19313 bytes
13 files changed, 785 insertions, 0 deletions
diff --git a/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py
new file mode 100644
index 0000000000..a58b6a247e
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py
@@ -0,0 +1,267 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for dense_image_warp."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+
+from tensorflow.contrib.image.python.ops import dense_image_warp
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+from tensorflow.python.training import adam
+
+
+class DenseImageWarpTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ np.random.seed(0)
+
+ def test_interpolate_small_grid_ij(self):
+ grid = constant_op.constant(
+ [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1])
+ query_points = constant_op.constant(
+ [[0., 0.], [1., 0.], [2., 0.5], [1.5, 1.5]], shape=[1, 4, 2])
+ expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1])
+
+ interp = dense_image_warp._interpolate_bilinear(grid, query_points)
+
+ with self.test_session() as sess:
+ predicted = sess.run(interp)
+ self.assertAllClose(expected_results, predicted)
+
+ def test_interpolate_small_grid_xy(self):
+ grid = constant_op.constant(
+ [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1])
+ query_points = constant_op.constant(
+ [[0., 0.], [0., 1.], [0.5, 2.0], [1.5, 1.5]], shape=[1, 4, 2])
+ expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1])
+
+ interp = dense_image_warp._interpolate_bilinear(
+ grid, query_points, indexing='xy')
+
+ with self.test_session() as sess:
+ predicted = sess.run(interp)
+ self.assertAllClose(expected_results, predicted)
+
+ def test_interpolate_small_grid_batched(self):
+ grid = constant_op.constant(
+ [[[0., 1.], [3., 4.]], [[5., 6.], [7., 8.]]], shape=[2, 2, 2, 1])
+ query_points = constant_op.constant([[[0., 0.], [1., 0.], [0.5, 0.5]],
+ [[0.5, 0.], [1., 0.], [1., 1.]]])
+ expected_results = np.reshape(
+ np.array([[0., 3., 2.], [6., 7., 8.]]), [2, 3, 1])
+
+ interp = dense_image_warp._interpolate_bilinear(grid, query_points)
+
+ with self.test_session() as sess:
+ predicted = sess.run(interp)
+ self.assertAllClose(expected_results, predicted)
+
+ def get_image_and_flow_placeholders(self, shape, image_type, flow_type):
+ batch_size, height, width, numchannels = shape
+ image_shape = [batch_size, height, width, numchannels]
+ flow_shape = [batch_size, height, width, 2]
+
+ tf_type = {
+ 'float16': dtypes.half,
+ 'float32': dtypes.float32,
+ 'float64': dtypes.float64
+ }
+
+ image = array_ops.placeholder(dtype=tf_type[image_type], shape=image_shape)
+
+ flows = array_ops.placeholder(dtype=tf_type[flow_type], shape=flow_shape)
+ return image, flows
+
+ def get_random_image_and_flows(self, shape, image_type, flow_type):
+ batch_size, height, width, numchannels = shape
+ image_shape = [batch_size, height, width, numchannels]
+ image = np.random.normal(size=image_shape)
+ flow_shape = [batch_size, height, width, 2]
+ flows = np.random.normal(size=flow_shape) * 3
+ return image.astype(image_type), flows.astype(flow_type)
+
+ def assert_correct_interpolation_value(self,
+ image,
+ flows,
+ pred_interpolation,
+ batch_index,
+ y_index,
+ x_index,
+ low_precision=False):
+ """Assert that the tf interpolation matches hand-computed value."""
+
+ height = image.shape[1]
+ width = image.shape[2]
+ displacement = flows[batch_index, y_index, x_index, :]
+ float_y = y_index - displacement[0]
+ float_x = x_index - displacement[1]
+ floor_y = max(min(height - 2, math.floor(float_y)), 0)
+ floor_x = max(min(width - 2, math.floor(float_x)), 0)
+ ceil_y = floor_y + 1
+ ceil_x = floor_x + 1
+
+ alpha_y = min(max(0.0, float_y - floor_y), 1.0)
+ alpha_x = min(max(0.0, float_x - floor_x), 1.0)
+
+ floor_y = int(floor_y)
+ floor_x = int(floor_x)
+ ceil_y = int(ceil_y)
+ ceil_x = int(ceil_x)
+
+ top_left = image[batch_index, floor_y, floor_x, :]
+ top_right = image[batch_index, floor_y, ceil_x, :]
+ bottom_left = image[batch_index, ceil_y, floor_x, :]
+ bottom_right = image[batch_index, ceil_y, ceil_x, :]
+
+ interp_top = alpha_x * (top_right - top_left) + top_left
+ interp_bottom = alpha_x * (bottom_right - bottom_left) + bottom_left
+ interp = alpha_y * (interp_bottom - interp_top) + interp_top
+ atol = 1e-6
+ rtol = 1e-6
+ if low_precision:
+ atol = 1e-2
+ rtol = 1e-3
+ self.assertAllClose(
+ interp,
+ pred_interpolation[batch_index, y_index, x_index, :],
+ atol=atol,
+ rtol=rtol)
+
+ def check_zero_flow_correctness(self, shape, image_type, flow_type):
+ """Assert using zero flows doesn't change the input image."""
+
+ image, flows = self.get_image_and_flow_placeholders(shape, image_type,
+ flow_type)
+ interp = dense_image_warp.dense_image_warp(image, flows)
+
+ with self.test_session() as sess:
+ rand_image, rand_flows = self.get_random_image_and_flows(
+ shape, image_type, flow_type)
+ rand_flows *= 0
+
+ predicted_interpolation = sess.run(
+ interp, feed_dict={
+ image: rand_image,
+ flows: rand_flows
+ })
+ self.assertAllClose(rand_image, predicted_interpolation)
+
+ def test_zero_flows(self):
+ """Apply check_zero_flow_correctness() for a few sizes and types."""
+
+ shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]]
+ for shape in shapes_to_try:
+ self.check_zero_flow_correctness(
+ shape, image_type='float32', flow_type='float32')
+
+ def check_interpolation_correctness(self,
+ shape,
+ image_type,
+ flow_type,
+ num_probes=5):
+ """Interpolate, and then assert correctness for a few query locations."""
+
+ image, flows = self.get_image_and_flow_placeholders(shape, image_type,
+ flow_type)
+ interp = dense_image_warp.dense_image_warp(image, flows)
+ low_precision = image_type == 'float16' or flow_type == 'float16'
+ with self.test_session() as sess:
+ rand_image, rand_flows = self.get_random_image_and_flows(
+ shape, image_type, flow_type)
+
+ pred_interpolation = sess.run(
+ interp, feed_dict={
+ image: rand_image,
+ flows: rand_flows
+ })
+
+ for _ in range(num_probes):
+ batch_index = np.random.randint(0, shape[0])
+ y_index = np.random.randint(0, shape[1])
+ x_index = np.random.randint(0, shape[2])
+
+ self.assert_correct_interpolation_value(
+ rand_image,
+ rand_flows,
+ pred_interpolation,
+ batch_index,
+ y_index,
+ x_index,
+ low_precision=low_precision)
+
+ def test_interpolation(self):
+ """Apply check_interpolation_correctness() for a few sizes and types."""
+
+ shapes_to_try = [[3, 4, 5, 6], [1, 5, 5, 3], [1, 2, 2, 1]]
+ for im_type in ['float32', 'float64', 'float16']:
+ for flow_type in ['float32', 'float64', 'float16']:
+ for shape in shapes_to_try:
+ self.check_interpolation_correctness(shape, im_type, flow_type)
+
+ def test_gradients_exist(self):
+ """Check that backprop can run.
+
+ The correctness of the gradients is assumed, since the forward propagation
+ is tested to be correct and we only use built-in tf ops.
+ However, we perform a simple test to make sure that backprop can actually
+ run. We treat the flows as a tf.Variable and optimize them to minimize
+ the difference between the interpolated image and the input image.
+ """
+
+ batch_size, height, width, numchannels = [4, 5, 6, 7]
+ image_shape = [batch_size, height, width, numchannels]
+ image = random_ops.random_normal(image_shape)
+ flow_shape = [batch_size, height, width, 2]
+ init_flows = np.float32(np.random.normal(size=flow_shape) * 0.25)
+ flows = variables.Variable(init_flows)
+
+ interp = dense_image_warp.dense_image_warp(image, flows)
+ loss = math_ops.reduce_mean(math_ops.square(interp - image))
+
+ optimizer = adam.AdamOptimizer(1.0)
+ grad = gradients.gradients(loss, [flows])
+ opt_func = optimizer.apply_gradients(zip(grad, [flows]))
+ init_op = variables.global_variables_initializer()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(10):
+ sess.run(opt_func)
+
+ def test_size_exception(self):
+ """Make sure it throws an exception for images that are too small."""
+
+ shape = [1, 2, 1, 1]
+ msg = 'Should have raised an exception for invalid image size'
+ with self.assertRaises(ValueError, msg=msg):
+ self.check_interpolation_correctness(shape, 'float32', 'float32')
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py
new file mode 100644
index 0000000000..1939caaa2d
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py
@@ -0,0 +1,264 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for interpolate_spline."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from scipy import interpolate as sc_interpolate
+
+from tensorflow.contrib.image.python.ops import interpolate_spline
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+from tensorflow.python.training import momentum
+
+
+class _InterpolationProblem(object):
+ """Abstract class for interpolation problem descriptions."""
+
+ def get_problem(self, optimizable=False, extrapolate=True, dtype='float32'):
+ """Make data for an interpolation problem where all x vectors are n-d.
+
+ Args:
+ optimizable: If True, then make train_points a tf.Variable.
+ extrapolate: If False, then clamp the query_points values to be within
+ the max and min of train_points.
+ dtype: The data type to use.
+
+ Returns:
+ query_points, query_values, train_points, train_values: training and
+ test tensors for interpolation problem
+ """
+
+ # The values generated here depend on a seed of 0.
+ np.random.seed(0)
+
+ batch_size = 1
+ num_training_points = 10
+ num_query_points = 4
+
+ init_points = np.random.uniform(
+ size=[batch_size, num_training_points, self.DATA_DIM])
+
+ init_points = init_points.astype(dtype)
+ train_points = (
+ variables.Variable(init_points)
+ if optimizable else constant_op.constant(init_points))
+ train_values = self.tf_function(train_points)
+
+ query_points_np = np.random.uniform(
+ size=[batch_size, num_query_points, self.DATA_DIM])
+ query_points_np = query_points_np.astype(dtype)
+ if not extrapolate:
+ query_points_np = np.clip(query_points_np, np.min(init_points),
+ np.max(init_points))
+
+ query_points = constant_op.constant(query_points_np)
+ query_values = self.np_function(query_points_np)
+
+ return query_points, query_values, train_points, train_values
+
+
+class _QuadraticPlusSinProblem1D(_InterpolationProblem):
+ """1D interpolation problem used for regression testing."""
+ DATA_DIM = 1
+ HARDCODED_QUERY_VALUES = {
+ (1.0, 0.0): [6.2647187603, -7.84362604077, -5.63690142322, 1.42928896387],
+ (1.0,
+ 0.01): [6.77688289946, -8.02163669853, -5.79491157027, 1.4063285693],
+ (2.0,
+ 0.0): [8.67110264937, -8.41281390883, -5.80190044693, 1.50155606059],
+ (2.0,
+ 0.01): [6.70797816797, -7.49709587663, -5.28965776238, 1.52284731741],
+ (3.0,
+ 0.0): [9.37691802935, -8.50390141515, -5.80786417426, 1.63467762122],
+ (3.0,
+ 0.01): [4.47106304758, -5.71266128361, -3.92529303296, 1.86755293857],
+ (4.0,
+ 0.0): [9.58172461111, -8.51432104771, -5.80967675388, 1.63361164256],
+ (4.0, 0.01): [
+ -3.87902711352, -0.0253462273846, 1.79857618022, -0.769339675725
+ ]
+ }
+
+ def np_function(self, x):
+ """Takes np array, evaluates the test function, and returns np array."""
+ return np.sum(
+ np.power((x - 0.5), 3) - 0.25 * x + 10 * np.sin(x * 10),
+ axis=2,
+ keepdims=True)
+
+ def tf_function(self, x):
+ """Takes tf tensor, evaluates the test function, and returns tf tensor."""
+ return math_ops.reduce_mean(
+ math_ops.pow((x - 0.5), 3) - 0.25 * x + 10 * math_ops.sin(x * 10),
+ 2,
+ keepdims=True)
+
+
+class _QuadraticPlusSinProblemND(_InterpolationProblem):
+ """3D interpolation problem used for regression testing."""
+
+ DATA_DIM = 3
+ HARDCODED_QUERY_VALUES = {
+ (1.0, 0.0): [1.06609663962, 1.28894849357, 1.10882405595, 1.63966936885],
+ (1.0, 0.01): [1.03123780748, 1.2952930985, 1.10366822954, 1.65265118569],
+ (2.0, 0.0): [0.627787735064, 1.43802857251, 1.00194632358, 1.91667538215],
+ (2.0, 0.01): [0.730159985046, 1.41702471595, 1.0065827217, 1.85758519312],
+ (3.0, 0.0): [0.350460417862, 1.67223539464, 1.00475331246, 2.31580322491],
+ (3.0,
+ 0.01): [0.624557250556, 1.63138876667, 0.976588193162, 2.12511237866],
+ (4.0,
+ 0.0): [0.898129669986, 1.24434133638, -0.938056116931, 1.59910338833],
+ (4.0,
+ 0.01): [0.0930360338179, -3.38791305538, -1.00969032567, 0.745535080382],
+ }
+
+ def np_function(self, x):
+ """Takes np array, evaluates the test function, and returns np array."""
+ return np.sum(
+ np.square(x - 0.5) + 0.25 * x + 1 * np.sin(x * 15),
+ axis=2,
+ keepdims=True)
+
+ def tf_function(self, x):
+ """Takes tf tensor, evaluates the test function, and returns tf tensor."""
+ return math_ops.reduce_sum(
+ math_ops.square(x - 0.5) + 0.25 * x + 1 * math_ops.sin(x * 15),
+ 2,
+ keepdims=True)
+
+
+class InterpolateSplineTest(test_util.TensorFlowTestCase):
+
+ def test_1d_linear_interpolation(self):
+ """For 1d linear interpolation, we can compare directly to scipy."""
+
+ tp = _QuadraticPlusSinProblem1D()
+ (query_points, _, train_points, train_values) = tp.get_problem(
+ extrapolate=False, dtype='float64')
+ interpolation_order = 1
+
+ with ops.name_scope('interpolator'):
+ interpolator = interpolate_spline.interpolate_spline(
+ train_points, train_values, query_points, interpolation_order)
+ with self.test_session() as sess:
+ fetches = [query_points, train_points, train_values, interpolator]
+ query_points_, train_points_, train_values_, interp_ = sess.run(fetches)
+
+ # Just look at the first element of the minibatch.
+ # Also, trim the final singleton dimension.
+ interp_ = interp_[0, :, 0]
+ query_points_ = query_points_[0, :, 0]
+ train_points_ = train_points_[0, :, 0]
+ train_values_ = train_values_[0, :, 0]
+
+ # Compute scipy interpolation.
+ scipy_interp_function = sc_interpolate.interp1d(
+ train_points_, train_values_, kind='linear')
+
+ scipy_interpolation = scipy_interp_function(query_points_)
+ scipy_interpolation_on_train = scipy_interp_function(train_points_)
+
+ # Even with float64 precision, the interpolants disagree with scipy a
+ # bit due to the fact that we add the EPSILON to prevent sqrt(0), etc.
+ tol = 1e-3
+
+ self.assertAllClose(
+ train_values_, scipy_interpolation_on_train, atol=tol, rtol=tol)
+ self.assertAllClose(interp_, scipy_interpolation, atol=tol, rtol=tol)
+
+ def test_1d_interpolation(self):
+ """Regression test for interpolation with 1-D points."""
+
+ tp = _QuadraticPlusSinProblem1D()
+ (query_points, _, train_points,
+ train_values) = tp.get_problem(dtype='float64')
+
+ for order in (1, 2, 3):
+ for reg_weight in (0, 0.01):
+ interpolator = interpolate_spline.interpolate_spline(
+ train_points, train_values, query_points, order, reg_weight)
+
+ target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)]
+ target_interpolation = np.array(target_interpolation)
+ with self.test_session() as sess:
+ interp_val = sess.run(interpolator)
+ self.assertAllClose(interp_val[0, :, 0], target_interpolation)
+
+ def test_nd_linear_interpolation(self):
+ """Regression test for interpolation with N-D points."""
+
+ tp = _QuadraticPlusSinProblemND()
+ (query_points, _, train_points,
+ train_values) = tp.get_problem(dtype='float64')
+
+ for order in (1, 2, 3):
+ for reg_weight in (0, 0.01):
+ interpolator = interpolate_spline.interpolate_spline(
+ train_points, train_values, query_points, order, reg_weight)
+
+ target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)]
+ target_interpolation = np.array(target_interpolation)
+ with self.test_session() as sess:
+ interp_val = sess.run(interpolator)
+ self.assertAllClose(interp_val[0, :, 0], target_interpolation)
+
+ def test_interpolation_gradient(self):
+ """Make sure that backprop can run. Correctness of gradients is assumed.
+
+ Here, we create a use a small 'training' set and a more densely-sampled
+ set of query points, for which we know the true value in advance. The goal
+ is to choose x locations for the training data such that interpolating using
+ this training data yields the best reconstruction for the function
+ values at the query points. The training data locations are optimized
+ iteratively using gradient descent.
+ """
+ tp = _QuadraticPlusSinProblemND()
+ (query_points, query_values, train_points,
+ train_values) = tp.get_problem(optimizable=True)
+
+ regularization = 0.001
+ for interpolation_order in (1, 2, 3, 4):
+ interpolator = interpolate_spline.interpolate_spline(
+ train_points, train_values, query_points, interpolation_order,
+ regularization)
+
+ loss = math_ops.reduce_mean(math_ops.square(query_values - interpolator))
+
+ optimizer = momentum.MomentumOptimizer(0.001, 0.9)
+ grad = gradients.gradients(loss, [train_points])
+ grad, _ = clip_ops.clip_by_global_norm(grad, 1.0)
+ opt_func = optimizer.apply_gradients(zip(grad, [train_points]))
+ init_op = variables.global_variables_initializer()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(100):
+ sess.run([loss, opt_func])
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py
new file mode 100644
index 0000000000..0135c66e29
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py
@@ -0,0 +1,254 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for sparse_image_warp."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.image.python.ops import sparse_image_warp
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import image_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+
+from tensorflow.python.training import momentum
+
+
+class SparseImageWarpTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ np.random.seed(0)
+
+ def testGetBoundaryLocations(self):
+ image_height = 11
+ image_width = 11
+ num_points_per_edge = 4
+ locs = sparse_image_warp._get_boundary_locations(image_height, image_width,
+ num_points_per_edge)
+ num_points = locs.shape[0]
+ self.assertEqual(num_points, 4 + 4 * num_points_per_edge)
+ locs = [(locs[i, 0], locs[i, 1]) for i in range(num_points)]
+ for i in (0, image_height - 1):
+ for j in (0, image_width - 1):
+ self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j))
+
+ for i in (2, 4, 6, 8):
+ for j in (0, image_width - 1):
+ self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j))
+
+ for i in (0, image_height - 1):
+ for j in (2, 4, 6, 8):
+ self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j))
+
+ def testGetGridLocations(self):
+ image_height = 5
+ image_width = 3
+ grid = sparse_image_warp._get_grid_locations(image_height, image_width)
+ for i in range(image_height):
+ for j in range(image_width):
+ self.assertEqual(grid[i, j, 0], i)
+ self.assertEqual(grid[i, j, 1], j)
+
+ def testZeroShift(self):
+ """Run assertZeroShift for various hyperparameters."""
+ for order in (1, 2):
+ for regularization in (0, 0.01):
+ for num_boundary_points in (0, 1):
+ self.assertZeroShift(order, regularization, num_boundary_points)
+
+ def assertZeroShift(self, order, regularization, num_boundary_points):
+ """Check that warping with zero displacements doesn't change the image."""
+ batch_size = 1
+ image_height = 4
+ image_width = 4
+ channels = 3
+
+ image = np.random.uniform(
+ size=[batch_size, image_height, image_width, channels])
+
+ input_image_op = constant_op.constant(np.float32(image))
+
+ control_point_locations = [[1., 1.], [2., 2.], [2., 1.]]
+ control_point_locations = constant_op.constant(
+ np.float32(np.expand_dims(control_point_locations, 0)))
+
+ control_point_displacements = np.zeros(
+ control_point_locations.shape.as_list())
+ control_point_displacements = constant_op.constant(
+ np.float32(control_point_displacements))
+
+ (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp(
+ input_image_op,
+ control_point_locations,
+ control_point_locations + control_point_displacements,
+ interpolation_order=order,
+ regularization_weight=regularization,
+ num_boundary_points=num_boundary_points)
+
+ with self.test_session() as sess:
+ warped_image, input_image, _ = sess.run(
+ [warped_image_op, input_image_op, flow_field])
+
+ self.assertAllClose(warped_image, input_image)
+
+ def testMoveSinglePixel(self):
+ """Run assertMoveSinglePixel for various hyperparameters and data types."""
+ for order in (1, 2):
+ for num_boundary_points in (1, 2):
+ for type_to_use in (dtypes.float32, dtypes.float64):
+ self.assertMoveSinglePixel(order, num_boundary_points, type_to_use)
+
+ def assertMoveSinglePixel(self, order, num_boundary_points, type_to_use):
+ """Move a single block in a small grid using warping."""
+ batch_size = 1
+ image_height = 7
+ image_width = 7
+ channels = 3
+
+ image = np.zeros([batch_size, image_height, image_width, channels])
+ image[:, 3, 3, :] = 1.0
+ input_image_op = constant_op.constant(image, dtype=type_to_use)
+
+ # Place a control point at the one white pixel.
+ control_point_locations = [[3., 3.]]
+ control_point_locations = constant_op.constant(
+ np.float32(np.expand_dims(control_point_locations, 0)),
+ dtype=type_to_use)
+ # Shift it one pixel to the right.
+ control_point_displacements = [[0., 1.0]]
+ control_point_displacements = constant_op.constant(
+ np.float32(np.expand_dims(control_point_displacements, 0)),
+ dtype=type_to_use)
+
+ (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp(
+ input_image_op,
+ control_point_locations,
+ control_point_locations + control_point_displacements,
+ interpolation_order=order,
+ num_boundary_points=num_boundary_points)
+
+ with self.test_session() as sess:
+ warped_image, input_image, flow = sess.run(
+ [warped_image_op, input_image_op, flow_field])
+ # Check that it moved the pixel correctly.
+ self.assertAllClose(
+ warped_image[0, 4, 5, :],
+ input_image[0, 4, 4, :],
+ atol=1e-5,
+ rtol=1e-5)
+
+ # Test that there is no flow at the corners.
+ for i in (0, image_height - 1):
+ for j in (0, image_width - 1):
+ self.assertAllClose(
+ flow[0, i, j, :], np.zeros([2]), atol=1e-5, rtol=1e-5)
+
+ def load_image(self, image_file, sess):
+ image_op = image_ops.decode_png(
+ io_ops.read_file(image_file), dtype=dtypes.uint8, channels=4)[:, :, 0:3]
+ return sess.run(image_op)
+
+ def testSmileyFace(self):
+ """Check warping accuracy by comparing to hardcoded warped images."""
+
+ test_data_dir = test.test_src_dir_path('contrib/image/python/'
+ 'kernel_tests/test_data/')
+ input_file = test_data_dir + 'Yellow_Smiley_Face.png'
+ with self.test_session() as sess:
+ input_image = self.load_image(input_file, sess)
+ control_points = np.asarray([[64, 59], [180 - 64, 59], [39, 111],
+ [180 - 39, 111], [90, 143], [58, 134],
+ [180 - 58, 134]]) # pyformat: disable
+ control_point_displacements = np.asarray(
+ [[-10.5, 10.5], [10.5, 10.5], [0, 0], [0, 0], [0, -10], [-20, 10.25],
+ [10, 10.75]])
+ control_points_op = constant_op.constant(
+ np.expand_dims(np.float32(control_points[:, [1, 0]]), 0))
+ control_point_displacements_op = constant_op.constant(
+ np.expand_dims(np.float32(control_point_displacements[:, [1, 0]]), 0))
+ float_image = np.expand_dims(np.float32(input_image) / 255, 0)
+ input_image_op = constant_op.constant(float_image)
+
+ for interpolation_order in (1, 2, 3):
+ for num_boundary_points in (0, 1, 4):
+ warp_op, _ = sparse_image_warp.sparse_image_warp(
+ input_image_op,
+ control_points_op,
+ control_points_op + control_point_displacements_op,
+ interpolation_order=interpolation_order,
+ num_boundary_points=num_boundary_points)
+ with self.test_session() as sess:
+ warped_image = sess.run(warp_op)
+ out_image = np.uint8(warped_image[0, :, :, :] * 255)
+ target_file = (
+ test_data_dir +
+ 'Yellow_Smiley_Face_Warp-interp' + '-{}-clamp-{}.png'.format(
+ interpolation_order, num_boundary_points))
+
+ target_image = self.load_image(target_file, sess)
+
+ # Check that the target_image and out_image difference is no
+ # bigger than 2 (on a scale of 0-255). Due to differences in
+ # floating point computation on different devices, the float
+ # output in warped_image may get rounded to a different int
+ # than that in the saved png file loaded into target_image.
+ self.assertAllClose(target_image, out_image, atol=2, rtol=1e-3)
+
+ def testThatBackpropRuns(self):
+ """Run optimization to ensure that gradients can be computed."""
+
+ batch_size = 1
+ image_height = 9
+ image_width = 12
+ image = variables.Variable(
+ np.float32(
+ np.random.uniform(size=[batch_size, image_height, image_width, 3])))
+ control_point_locations = [[3., 3.]]
+ control_point_locations = constant_op.constant(
+ np.float32(np.expand_dims(control_point_locations, 0)))
+ control_point_displacements = [[0.25, -0.5]]
+ control_point_displacements = constant_op.constant(
+ np.float32(np.expand_dims(control_point_displacements, 0)))
+ warped_image, _ = sparse_image_warp.sparse_image_warp(
+ image,
+ control_point_locations,
+ control_point_locations + control_point_displacements,
+ num_boundary_points=3)
+
+ loss = math_ops.reduce_mean(math_ops.abs(warped_image - image))
+ optimizer = momentum.MomentumOptimizer(0.001, 0.9)
+ grad = gradients.gradients(loss, [image])
+ grad, _ = clip_ops.clip_by_global_norm(grad, 1.0)
+ opt_func = optimizer.apply_gradients(zip(grad, [image]))
+ init_op = variables.global_variables_initializer()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run([loss, opt_func])
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png
new file mode 100644
index 0000000000..7e303881e2
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png
new file mode 100644
index 0000000000..7fd9e4e6d6
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png
new file mode 100644
index 0000000000..86d225e5d2
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png
new file mode 100644
index 0000000000..37e8ffae11
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png
new file mode 100644
index 0000000000..e49b581612
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png
new file mode 100644
index 0000000000..df3cf20043
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png
new file mode 100644
index 0000000000..e1799a87c8
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png
new file mode 100644
index 0000000000..2c346e0ce5
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png
new file mode 100644
index 0000000000..6f8b65451c
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png
Binary files differ
diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png
new file mode 100644
index 0000000000..8e78146d95
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png
Binary files differ