aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image/python/ops/interpolate_spline.py
blob: f0b408faa3320741cf83b3aaec0f40030f906578 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Polyharmonic spline interpolation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops

# Lower bound applied to (squared) distances inside _phi before sqrt/log/pow,
# so log(0) never occurs and sqrt's gradient stays finite at zero distance.
EPSILON = 1e-10


def _cross_squared_distance_matrix(x, y):
  """Pairwise squared distance between two (batch) matrices' rows (2nd dim).

  Uses the expansion ||a - b||^2 = a'a - 2 a'b + b'b, so every (i, j) pair is
  obtained from one batched matmul plus broadcasting instead of materializing
  a tensor of all pairwise differences.

  Args:
    x: [batch_size, n, d] float `Tensor`
    y: [batch_size, m, d] float `Tensor`

  Returns:
    squared_dists: [batch_size, n, m] float `Tensor`, where
    squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2. Floating-point
    cancellation in the expansion can make entries slightly negative.
  """
  # Cross term x_bi'y_bj for every pair (i, j): [batch_size, n, m].
  inner = math_ops.matmul(x, y, adjoint_b=True)

  # Squared row norms reshaped to [batch_size, n, 1] and [batch_size, 1, m] so
  # they broadcast along the columns and the rows of `inner` respectively.
  x_sq = array_ops.expand_dims(math_ops.reduce_sum(math_ops.square(x), 2), 2)
  y_sq = array_ops.expand_dims(math_ops.reduce_sum(math_ops.square(y), 2), 1)

  return x_sq - 2 * inner + y_sq


def _pairwise_squared_distance_matrix(x):
  """Pairwise squared distance among a (batch) matrix's rows (2nd dim).

  Cheaper than _cross_squared_distance_matrix(x, x): the squared row norms are
  read off the diagonal of the Gram matrix rather than recomputed.

  Args:
    x: `[batch_size, n, d]` float `Tensor`

  Returns:
    squared_dists: `[batch_size, n, n]` float `Tensor`, where
    squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2
  """
  # Gram matrix: gram[b, i, j] = x_bi'x_bj.
  gram = math_ops.matmul(x, x, adjoint_b=True)
  # Its diagonal holds the squared norms x_bi'x_bi, shape [batch_size, n].
  sq_norms = array_ops.matrix_diag_part(gram)

  # Column form [batch_size, n, 1] and its transpose [batch_size, 1, n];
  # broadcasting yields ||x_bi||^2 - 2 x_bi'x_bj + ||x_bj||^2.
  as_col = array_ops.expand_dims(sq_norms, 2)
  as_row = array_ops.transpose(as_col, [0, 2, 1])
  return as_col - 2 * gram + as_row


def _solve_interpolation(train_points, train_values, order,
                         regularization_weight):
  """Solve for interpolation coefficients.

  Computes the coefficients of the polyharmonic interpolant for the 'training'
  data defined by (train_points, train_values) using the kernel phi.

  The coefficients are the solution of the block linear system

      [[A,  B], [B', 0]] [w; v] = [f; 0]

  where A = _phi(pairwise squared distances between the centers c, order),
  with regularization_weight added to its diagonal when it is positive, and
  B = [c, 1] is the centers with a column of ones appended for the bias term.
  The zero block on the right-hand side imposes the side conditions
  sum_i w_i c_i = 0 and sum_i w_i = 0.

  Args:
    train_points: `[b, n, d]` interpolation centers
    train_values: `[b, n, k]` function values
    order: order of the interpolation
    regularization_weight: weight to place on smoothness regularization term

  Returns:
    w: `[b, n, k]` weights on each interpolation center
    v: `[b, d + 1, k]` weights on each input dimension; the last row holds the
      weights of the appended bias (ones) column.
  Raises:
    ValueError: if d or k is not fully specified.
  """

  # These dimensions are set dynamically at runtime.
  b, n, _ = array_ops.unstack(array_ops.shape(train_points), num=3)

  d = train_points.shape[-1]
  if d.value is None:
    raise ValueError('The dimensionality of the input points (d) must be '
                     'statically-inferrable.')

  k = train_values.shape[-1]
  if k.value is None:
    raise ValueError('The dimensionality of the output values (k) must be '
                     'statically-inferrable.')

  # First, rename variables so that the notation (c, f, w, v, A, B, etc.)
  # follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
  # To account for python style guidelines we use
  # matrix_a for A and matrix_b for B.

  c = train_points
  f = train_values

  # Next, construct the linear system.
  with ops.name_scope('construct_linear_system'):

    matrix_a = _phi(_pairwise_squared_distance_matrix(c), order)  # [b, n, n]
    if regularization_weight > 0:
      batch_identity_matrix = array_ops.expand_dims(
          linalg_ops.eye(n, dtype=c.dtype), 0)
      matrix_a += regularization_weight * batch_identity_matrix

    # Append ones to the feature values for the bias term in the linear model.
    ones = array_ops.ones_like(c[..., :1], dtype=c.dtype)
    matrix_b = array_ops.concat([c, ones], 2)  # [b, n, d + 1]

    # [b, n + d + 1, n]
    left_block = array_ops.concat(
        [matrix_a, array_ops.transpose(matrix_b, [0, 2, 1])], 1)

    # Static size d + 1 (input dimensions plus the bias column); used for both
    # the zero block of the LHS and the zero rows of the RHS.
    num_b_cols = matrix_b.get_shape()[2]
    lhs_zeros = array_ops.zeros([b, num_b_cols, num_b_cols], train_points.dtype)
    right_block = array_ops.concat([matrix_b, lhs_zeros],
                                   1)  # [b, n + d + 1, d + 1]
    lhs = array_ops.concat([left_block, right_block],
                           2)  # [b, n + d + 1, n + d + 1]

    rhs_zeros = array_ops.zeros([b, num_b_cols, k], train_points.dtype)
    rhs = array_ops.concat([f, rhs_zeros], 1)  # [b, n + d + 1, k]

  # Then, solve the linear system and unpack the results.
  with ops.name_scope('solve_linear_system'):
    w_v = linalg_ops.matrix_solve(lhs, rhs)
    w = w_v[:, :n, :]  # [b, n, k]
    v = w_v[:, n:, :]  # [b, d + 1, k]

  return w, v


def _apply_interpolation(query_points, train_points, w, v, order):
  """Apply polyharmonic interpolation model to data.

  Given coefficients w and v for the interpolation model, we evaluate
  interpolated function values at query_points as

      _phi(||q - c||^2, order) w + [q, 1] v

  i.e. an RBF term over the centers plus a linear term with a bias.

  Args:
    query_points: `[b, m, d]` x values to evaluate the interpolation at
    train_points: `[b, n, d]` x values that act as the interpolation centers
      (the c variables in the wikipedia article)
    w: `[b, n, k]` weights on each interpolation center
    v: `[b, d + 1, k]` weights on each input dimension, with the weights of
      the bias term in the last row
    order: order of the interpolation

  Returns:
    `[b, m, k]` polyharmonic interpolation evaluated at points defined in
    query_points.
  """

  # First, compute the contribution from the rbf term.
  pairwise_dists = _cross_squared_distance_matrix(query_points, train_points)
  phi_pairwise_dists = _phi(pairwise_dists, order)  # [b, m, n]

  rbf_term = math_ops.matmul(phi_pairwise_dists, w)  # [b, m, k]

  # Then, compute the contribution from the linear term.
  # Pad query_points with ones, for the bias term in the linear model; this
  # mirrors the ones column appended to the centers when solving for v.
  query_points_pad = array_ops.concat([
      query_points,
      array_ops.ones_like(query_points[..., :1], train_points.dtype)
  ], 2)  # [b, m, d + 1]
  linear_term = math_ops.matmul(query_points_pad, v)  # [b, m, k]

  return rbf_term + linear_term


def _phi(r, order):
  """Coordinate-wise nonlinearity used to define the order of the interpolation.

  See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.

  The callers in this module pass *squared* distances as `r`. Writing
  dist = sqrt(r), the branches below compute dist^order for odd orders and
  dist^order * log(dist) for even orders, using 0.5 * log(r) == log(dist).

  Args:
    r: `Tensor` of squared distances.
    order: interpolation order. It is used in Python control flow
      (`==`, `%`), so it must be a Python number, not a `Tensor`.

  Returns:
    phi_k evaluated coordinate-wise on r, for k = order.
  """

  # Clamping r to EPSILON prevents log(0) and keeps sqrt/pow away from 0:
  # sqrt(0) is well-defined, but its gradient is not.
  with ops.name_scope('phi'):
    if order == 1:
      r = math_ops.maximum(r, EPSILON)
      r = math_ops.sqrt(r)
      return r
    elif order == 2:
      # dist^2 * log(dist) == 0.5 * r * log(r); only the log input is clamped.
      return 0.5 * r * math_ops.log(math_ops.maximum(r, EPSILON))
    elif order == 4:
      # dist^4 * log(dist) == 0.5 * r^2 * log(r); square avoids a generic pow.
      return 0.5 * math_ops.square(r) * math_ops.log(
          math_ops.maximum(r, EPSILON))
    elif order % 2 == 0:
      r = math_ops.maximum(r, EPSILON)
      return 0.5 * math_ops.pow(r, 0.5 * order) * math_ops.log(r)
    else:
      r = math_ops.maximum(r, EPSILON)
      return math_ops.pow(r, 0.5 * order)


def interpolate_spline(train_points,
                       train_values,
                       query_points,
                       order,
                       regularization_weight=0.0,
                       name='interpolate_spline'):
  r"""Interpolate signal using polyharmonic interpolation.

  The interpolant has the form
  $$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$

  This is a sum of two terms: (1) a weighted sum of radial basis function (RBF)
  terms, with the centers \\(c_1, ... c_n\\), and (2) a linear term with a bias.
  The \\(c_i\\) vectors are 'training' points. In the code, b is absorbed into v
  by appending 1 as a final dimension to x. The coefficients w and v are
  estimated such that the interpolant exactly fits the value of the function at
  the \\(c_i\\) points, \\(\sum_i w_i c_i = 0\\) (i.e. w is orthogonal to each
  of the d coordinate columns of the centers), and the vector w sums to 0.
  These conditions hold separately for each output channel. With these
  constraints, the coefficients can be obtained by solving a linear system.

  \\(\phi\\) is an RBF, parametrized by an interpolation
  order. Using order=2 produces the well-known thin-plate spline.

  We also provide the option to perform regularized interpolation. Here, the
  interpolant is selected to trade off between the squared loss on the training
  data and a certain measure of its curvature
  ([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)).
  Using a regularization weight greater than zero has the effect that the
  interpolant will no longer exactly fit the training data. However, it may be
  less vulnerable to overfitting, particularly for high-order interpolation.

  Note the interpolation procedure is differentiable with respect to all inputs
  besides the order parameter.

  We support dynamically-shaped inputs, where batch_size, n, and m are None
  at graph construction time. However, d and k must be known.

  Args:
    train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional
      locations. These do not need to be regularly-spaced.
    train_values: `[batch_size, n, k]` float `Tensor` of n k-dimensional values
      evaluated at train_points.
    query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations
      where we will output the interpolant's values.
    order: order of the interpolation. Common values are 1 for
      \\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\) (thin-plate spline),
      or 3 for \\(\phi(r) = r^3\\).
    regularization_weight: weight placed on the regularization term.
      This will depend substantially on the problem, and it should always be
      tuned. For many problems, it is reasonable to use no regularization.
      If using a non-zero value, we recommend a small value like 0.001.
    name: name prefix for ops created by this function

  Returns:
    `[batch_size, m, k]` float `Tensor` of query values. We use train_points
    and train_values to perform polyharmonic interpolation. The query values
    are the values of the interpolant evaluated at the locations specified in
    query_points.

  Raises:
    ValueError: if d (last dimension of train_points) or k (last dimension of
      train_values) is not statically known.
  """
  with ops.name_scope(name):
    train_points = ops.convert_to_tensor(train_points)
    train_values = ops.convert_to_tensor(train_values)
    query_points = ops.convert_to_tensor(query_points)

    # First, fit the spline to the observed data.
    with ops.name_scope('solve'):
      w, v = _solve_interpolation(train_points, train_values, order,
                                  regularization_weight)

    # Then, evaluate the spline at the query locations.
    with ops.name_scope('predict'):
      query_values = _apply_interpolation(query_points, train_points, w, v,
                                          order)

  return query_values