diff options
author | 2018-08-17 04:07:15 -0700 | |
---|---|---|
committer | 2018-08-17 04:11:11 -0700 | |
commit | 0cffe2dba0c2000a8c719c2ed499a3ee72d6a2b6 (patch) | |
tree | a22bfa4dbcc98c776692e0abda0310590e09428e /tensorflow/contrib/image | |
parent | fa90b3f9173ef315f98416c52a74fc669de60454 (diff) |
update tf.contrib.image.interpolate_spline to support inputs with partially-specified shapes.
Fixes #21136
PiperOrigin-RevId: 209130870
Diffstat (limited to 'tensorflow/contrib/image')
-rw-r--r-- | tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py | 76 | ||||
-rw-r--r-- | tensorflow/contrib/image/python/ops/interpolate_spline.py | 35 |
2 files changed, 97 insertions, 14 deletions
diff --git a/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py index 1939caaa2d..3054128979 100644 --- a/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import gradients from tensorflow.python.ops import math_ops @@ -226,6 +227,81 @@ class InterpolateSplineTest(test_util.TensorFlowTestCase): interp_val = sess.run(interpolator) self.assertAllClose(interp_val[0, :, 0], target_interpolation) + def test_nd_linear_interpolation_unspecified_shape(self): + """Ensure that interpolation supports dynamic batch_size and num_points.""" + + tp = _QuadraticPlusSinProblemND() + (query_points, _, train_points, + train_values) = tp.get_problem(dtype='float64') + + # Construct placeholders such that the batch size, number of train points, + # and number of query points are not known at graph construction time. 
+ feature_dim = query_points.shape[-1] + value_dim = train_values.shape[-1] + train_points_ph = array_ops.placeholder( + dtype=train_points.dtype, shape=[None, None, feature_dim]) + train_values_ph = array_ops.placeholder( + dtype=train_values.dtype, shape=[None, None, value_dim]) + query_points_ph = array_ops.placeholder( + dtype=query_points.dtype, shape=[None, None, feature_dim]) + + order = 1 + reg_weight = 0.01 + + interpolator = interpolate_spline.interpolate_spline( + train_points_ph, train_values_ph, query_points_ph, order, reg_weight) + + target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)] + target_interpolation = np.array(target_interpolation) + with self.test_session() as sess: + + (train_points_value, train_values_value, query_points_value) = sess.run( + [train_points, train_values, query_points]) + + interp_val = sess.run( + interpolator, + feed_dict={ + train_points_ph: train_points_value, + train_values_ph: train_values_value, + query_points_ph: query_points_value + }) + self.assertAllClose(interp_val[0, :, 0], target_interpolation) + + def test_fully_unspecified_shape(self): + """Ensure that error is thrown when input/output dim unspecified.""" + + tp = _QuadraticPlusSinProblemND() + (query_points, _, train_points, + train_values) = tp.get_problem(dtype='float64') + + # Construct placeholders such that the batch size, number of train points, + # and number of query points are not known at graph construction time.
+ feature_dim = query_points.shape[-1] + value_dim = train_values.shape[-1] + train_points_ph = array_ops.placeholder( + dtype=train_points.dtype, shape=[None, None, feature_dim]) + train_points_ph_invalid = array_ops.placeholder( + dtype=train_points.dtype, shape=[None, None, None]) + train_values_ph = array_ops.placeholder( + dtype=train_values.dtype, shape=[None, None, value_dim]) + train_values_ph_invalid = array_ops.placeholder( + dtype=train_values.dtype, shape=[None, None, None]) + query_points_ph = array_ops.placeholder( + dtype=query_points.dtype, shape=[None, None, feature_dim]) + + order = 1 + reg_weight = 0.01 + + with self.assertRaises(ValueError): + _ = interpolate_spline.interpolate_spline( + train_points_ph_invalid, train_values_ph, query_points_ph, order, + reg_weight) + + with self.assertRaises(ValueError): + _ = interpolate_spline.interpolate_spline( + train_points_ph, train_values_ph_invalid, query_points_ph, order, + reg_weight) + def test_interpolation_gradient(self): """Make sure that backprop can run. Correctness of gradients is assumed. diff --git a/tensorflow/contrib/image/python/ops/interpolate_spline.py b/tensorflow/contrib/image/python/ops/interpolate_spline.py index daf8c56456..f0b408faa3 100644 --- a/tensorflow/contrib/image/python/ops/interpolate_spline.py +++ b/tensorflow/contrib/image/python/ops/interpolate_spline.py @@ -17,9 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops @@ -95,10 +92,22 @@ def _solve_interpolation(train_points, train_values, order, Returns: w: `[b, n, k]` weights on each interpolation center v: `[b, d, k]` weights on each input dimension + Raises: + ValueError: if d or k is not fully specified. 
""" - b, n, d = train_points.get_shape().as_list() - _, _, k = train_values.get_shape().as_list() + # These dimensions are set dynamically at runtime. + b, n, _ = array_ops.unstack(array_ops.shape(train_points), num=3) + + d = train_points.shape[-1] + if d.value is None: + raise ValueError('The dimensionality of the input points (d) must be ' + 'statically-inferrable.') + + k = train_values.shape[-1] + if k.value is None: + raise ValueError('The dimensionality of the output values (k) must be ' + 'statically-inferrable.') # First, rename variables so that the notation (c, f, w, v, A, B, etc.) # follows https://en.wikipedia.org/wiki/Polyharmonic_spline. @@ -113,14 +122,12 @@ def _solve_interpolation(train_points, train_values, order, matrix_a = _phi(_pairwise_squared_distance_matrix(c), order) # [b, n, n] if regularization_weight > 0: - batch_identity_matrix = np.expand_dims(np.eye(n), 0) - batch_identity_matrix = constant_op.constant( - batch_identity_matrix, dtype=train_points.dtype) - + batch_identity_matrix = array_ops.expand_dims( + linalg_ops.eye(n, dtype=c.dtype), 0) matrix_a += regularization_weight * batch_identity_matrix # Append ones to the feature values for the bias term in the linear model. - ones = array_ops.ones([b, n, 1], train_points.dtype) + ones = array_ops.ones_like(c[..., :1], dtype=c.dtype) matrix_b = array_ops.concat([c, ones], 2) # [b, n, d + 1] # [b, n + d + 1, n] @@ -164,9 +171,6 @@ def _apply_interpolation(query_points, train_points, w, v, order): Polyharmonic interpolation evaluated at points defined in query_points. """ - batch_size = train_points.get_shape()[0].value - num_query_points = query_points.get_shape()[1].value - # First, compute the contribution from the rbf term. 
pairwise_dists = _cross_squared_distance_matrix(query_points, train_points) phi_pairwise_dists = _phi(pairwise_dists, order) @@ -177,7 +181,7 @@ def _apply_interpolation(query_points, train_points, w, v, order): # Pad query_points with ones, for the bias term in the linear model. query_points_pad = array_ops.concat([ query_points, - array_ops.ones([batch_size, num_query_points, 1], train_points.dtype) + array_ops.ones_like(query_points[..., :1], train_points.dtype) ], 2) linear_term = math_ops.matmul(query_points_pad, v) @@ -251,6 +255,9 @@ def interpolate_spline(train_points, Note the interpolation procedure is differentiable with respect to all inputs besides the order parameter. + We support dynamically-shaped inputs, where batch_size, n, and m are None + at graph construction time. However, d and k must be known. + Args: train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional locations. These do not need to be regularly-spaced. |