aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-17 04:07:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-17 04:11:11 -0700
commit0cffe2dba0c2000a8c719c2ed499a3ee72d6a2b6 (patch)
treea22bfa4dbcc98c776692e0abda0310590e09428e /tensorflow/contrib/image
parentfa90b3f9173ef315f98416c52a74fc669de60454 (diff)
update tf.contrib.image.interpolate_spline to support inputs with partially-specified shapes.
Fixes #21136 PiperOrigin-RevId: 209130870
Diffstat (limited to 'tensorflow/contrib/image')
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py76
-rw-r--r--tensorflow/contrib/image/python/ops/interpolate_spline.py35
2 files changed, 97 insertions, 14 deletions
diff --git a/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py
index 1939caaa2d..3054128979 100644
--- a/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
@@ -226,6 +227,81 @@ class InterpolateSplineTest(test_util.TensorFlowTestCase):
interp_val = sess.run(interpolator)
self.assertAllClose(interp_val[0, :, 0], target_interpolation)
+ def test_nd_linear_interpolation_unspecified_shape(self):
+ """Ensure that interpolation supports dynamic batch_size and num_points."""
+
+ tp = _QuadraticPlusSinProblemND()
+ (query_points, _, train_points,
+ train_values) = tp.get_problem(dtype='float64')
+
+ # Construct placeholders such that the batch size, number of train points,
+ # and number of query points are not known at graph construction time.
+ feature_dim = query_points.shape[-1]
+ value_dim = train_values.shape[-1]
+ train_points_ph = array_ops.placeholder(
+ dtype=train_points.dtype, shape=[None, None, feature_dim])
+ train_values_ph = array_ops.placeholder(
+ dtype=train_values.dtype, shape=[None, None, value_dim])
+ query_points_ph = array_ops.placeholder(
+ dtype=query_points.dtype, shape=[None, None, feature_dim])
+
+ order = 1
+ reg_weight = 0.01
+
+ interpolator = interpolate_spline.interpolate_spline(
+ train_points_ph, train_values_ph, query_points_ph, order, reg_weight)
+
+ target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)]
+ target_interpolation = np.array(target_interpolation)
+ with self.test_session() as sess:
+
+ (train_points_value, train_values_value, query_points_value) = sess.run(
+ [train_points, train_values, query_points])
+
+ interp_val = sess.run(
+ interpolator,
+ feed_dict={
+ train_points_ph: train_points_value,
+ train_values_ph: train_values_value,
+ query_points_ph: query_points_value
+ })
+ self.assertAllClose(interp_val[0, :, 0], target_interpolation)
+
+ def test_fully_unspecified_shape(self):
+ """Ensure that erreor is thrown when input/output dim unspecified."""
+
+ tp = _QuadraticPlusSinProblemND()
+ (query_points, _, train_points,
+ train_values) = tp.get_problem(dtype='float64')
+
+ # Construct placeholders such that the batch size, number of train points,
+ # and number of query points are not known at graph construction time.
+ feature_dim = query_points.shape[-1]
+ value_dim = train_values.shape[-1]
+ train_points_ph = array_ops.placeholder(
+ dtype=train_points.dtype, shape=[None, None, feature_dim])
+ train_points_ph_invalid = array_ops.placeholder(
+ dtype=train_points.dtype, shape=[None, None, None])
+ train_values_ph = array_ops.placeholder(
+ dtype=train_values.dtype, shape=[None, None, value_dim])
+ train_values_ph_invalid = array_ops.placeholder(
+ dtype=train_values.dtype, shape=[None, None, None])
+ query_points_ph = array_ops.placeholder(
+ dtype=query_points.dtype, shape=[None, None, feature_dim])
+
+ order = 1
+ reg_weight = 0.01
+
+ with self.assertRaises(ValueError):
+ _ = interpolate_spline.interpolate_spline(
+ train_points_ph_invalid, train_values_ph, query_points_ph, order,
+ reg_weight)
+
+ with self.assertRaises(ValueError):
+ _ = interpolate_spline.interpolate_spline(
+ train_points_ph, train_values_ph_invalid, query_points_ph, order,
+ reg_weight)
+
def test_interpolation_gradient(self):
"""Make sure that backprop can run. Correctness of gradients is assumed.
diff --git a/tensorflow/contrib/image/python/ops/interpolate_spline.py b/tensorflow/contrib/image/python/ops/interpolate_spline.py
index daf8c56456..f0b408faa3 100644
--- a/tensorflow/contrib/image/python/ops/interpolate_spline.py
+++ b/tensorflow/contrib/image/python/ops/interpolate_spline.py
@@ -17,9 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
@@ -95,10 +92,22 @@ def _solve_interpolation(train_points, train_values, order,
Returns:
w: `[b, n, k]` weights on each interpolation center
v: `[b, d, k]` weights on each input dimension
+ Raises:
+ ValueError: if d or k is not fully specified.
"""
- b, n, d = train_points.get_shape().as_list()
- _, _, k = train_values.get_shape().as_list()
+ # These dimensions are set dynamically at runtime.
+ b, n, _ = array_ops.unstack(array_ops.shape(train_points), num=3)
+
+ d = train_points.shape[-1]
+ if d.value is None:
+ raise ValueError('The dimensionality of the input points (d) must be '
+ 'statically-inferrable.')
+
+ k = train_values.shape[-1]
+ if k.value is None:
+ raise ValueError('The dimensionality of the output values (k) must be '
+ 'statically-inferrable.')
# First, rename variables so that the notation (c, f, w, v, A, B, etc.)
# follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
@@ -113,14 +122,12 @@ def _solve_interpolation(train_points, train_values, order,
matrix_a = _phi(_pairwise_squared_distance_matrix(c), order) # [b, n, n]
if regularization_weight > 0:
- batch_identity_matrix = np.expand_dims(np.eye(n), 0)
- batch_identity_matrix = constant_op.constant(
- batch_identity_matrix, dtype=train_points.dtype)
-
+ batch_identity_matrix = array_ops.expand_dims(
+ linalg_ops.eye(n, dtype=c.dtype), 0)
matrix_a += regularization_weight * batch_identity_matrix
# Append ones to the feature values for the bias term in the linear model.
- ones = array_ops.ones([b, n, 1], train_points.dtype)
+ ones = array_ops.ones_like(c[..., :1], dtype=c.dtype)
matrix_b = array_ops.concat([c, ones], 2) # [b, n, d + 1]
# [b, n + d + 1, n]
@@ -164,9 +171,6 @@ def _apply_interpolation(query_points, train_points, w, v, order):
Polyharmonic interpolation evaluated at points defined in query_points.
"""
- batch_size = train_points.get_shape()[0].value
- num_query_points = query_points.get_shape()[1].value
-
# First, compute the contribution from the rbf term.
pairwise_dists = _cross_squared_distance_matrix(query_points, train_points)
phi_pairwise_dists = _phi(pairwise_dists, order)
@@ -177,7 +181,7 @@ def _apply_interpolation(query_points, train_points, w, v, order):
# Pad query_points with ones, for the bias term in the linear model.
query_points_pad = array_ops.concat([
query_points,
- array_ops.ones([batch_size, num_query_points, 1], train_points.dtype)
+ array_ops.ones_like(query_points[..., :1], train_points.dtype)
], 2)
linear_term = math_ops.matmul(query_points_pad, v)
@@ -251,6 +255,9 @@ def interpolate_spline(train_points,
Note the interpolation procedure is differentiable with respect to all inputs
besides the order parameter.
+ We support dynamically-shaped inputs, where batch_size, n, and m are None
+ at graph construction time. However, d and k must be known.
+
Args:
train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional
locations. These do not need to be regularly-spaced.