aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-12 08:42:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-12 08:45:53 -0700
commitdeb845fc79bcfe4d534a7050cc8e342f86db9dd0 (patch)
treea38d9bf033b2a29074958ddd1cb073d1026a566f
parentdf1f2a0964faf66677c30cf56526b568d355597f (diff)
Added optional argument to specify time step to contrib.integrate.odeint_fixed.
PiperOrigin-RevId: 200220800
-rw-r--r--tensorflow/contrib/integrate/python/ops/odes.py126
-rw-r--r--tensorflow/contrib/integrate/python/ops/odes_test.py51
2 files changed, 147 insertions, 30 deletions
diff --git a/tensorflow/contrib/integrate/python/ops/odes.py b/tensorflow/contrib/integrate/python/ops/odes.py
index b4a99867ed..61f78febfc 100644
--- a/tensorflow/contrib/integrate/python/ops/odes.py
+++ b/tensorflow/contrib/integrate/python/ops/odes.py
@@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
@@ -279,13 +278,27 @@ def _assert_increasing(t):
return ops.control_dependencies([assert_increasing])
-def _check_input_types(t, y0):
+def _check_input_types(y0, t, dt=None):
if not (y0.dtype.is_floating or y0.dtype.is_complex):
raise TypeError('`y0` must have a floating point or complex floating '
'point dtype')
if not t.dtype.is_floating:
raise TypeError('`t` must have a floating point dtype')
+ if dt is not None and not dt.dtype.is_floating:
+ raise TypeError('`dt` must have a floating point dtype')
+
+
+def _check_input_sizes(t, dt):
+ if len(t.get_shape().as_list()) > 1:
+ raise ValueError('t must be a 1D tensor')
+
+ if len(dt.get_shape().as_list()) > 1:
+ raise ValueError('t must be a 1D tensor')
+
+ if t.get_shape()[0] != dt.get_shape()[0] + 1:
+ raise ValueError('t and dt have incompatible lengths, must be N and N-1')
+
def _dopri5(func,
y0,
@@ -510,7 +523,7 @@ def odeint(func,
# avoiding the need to pack/unpack in user functions.
y0 = ops.convert_to_tensor(y0, name='y0')
t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
- _check_input_types(t, y0)
+ _check_input_types(y0, t)
error_dtype = abs(y0).dtype
rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol')
@@ -530,24 +543,74 @@ def odeint(func,
class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)):
"""Base class for fixed-grid ODE integrators."""
- def integrate(self, evol_func, y0, time_grid):
- time_delta_grid = time_grid[1:] - time_grid[:-1]
-
- scan_func = self._make_scan_func(evol_func)
+ def integrate(self, evol_func, y0, time_grid, dt_grid, steps_on_intervals):
+ """Returns integrated values of differential equation on the `time grid`.
+
+ Numerically integrates differential equation defined via time derivative
+ evaluator `evol_func` using fixed time steps specified in dt_grid.
+
+ Args:
+ evol_func: Callable, evaluates time derivative of y at a given time.
+ y0: N-D Tensor holds initial values of the solution.
+ time_grid: 1-D Tensor holding the time points at which the solution
+ will be recorded, must have a floating dtype.
+ dt_grid: 1-D Tensor holds fixed time steps to be used on time_grid
+ intervals. Must be a floating dtype and have one less element than that
+ of the time_grid.
+ steps_on_intervals: 1-D Tensor of integer dtype, must have the same size
+ as dt_grid. Specifies number of steps needed for every interval. Assumes
+ steps_on_intervals * dt_grid == time intervals.
+
+ Returns:
+ (N+1)-D tensor, where the first dimension corresponds to different
+ time points. Contains the solved value of y for each desired time point in
+ `t`, with the initial value `y0` being the first element along the first
+ dimension.
+ """
- y_grid = functional_ops.scan(scan_func, (time_grid[:-1], time_delta_grid),
- y0)
- return array_ops.concat([[y0], y_grid], axis=0)
+ iteration_func = self._make_iteration_func(evol_func, dt_grid)
+ integrate_interval = self._make_interval_integrator(iteration_func,
+ steps_on_intervals)
- def _make_scan_func(self, evol_func):
+ num_times = array_ops.size(time_grid)
+ current_time = time_grid[0]
+ solution_array = tensor_array_ops.TensorArray(y0.dtype, num_times)
+ solution_array = solution_array.write(0, y0)
- def scan_func(y, t_and_dt):
- t, dt = t_and_dt
+ solution_array, _, _, _ = control_flow_ops.while_loop(
+ lambda _, __, ___, i: i < num_times,
+ integrate_interval,
+ (solution_array, y0, current_time, 1)
+ )
+ solution_array = solution_array.stack()
+ solution_array.set_shape(time_grid.get_shape().concatenate(y0.get_shape()))
+ return solution_array
+
+ def _make_iteration_func(self, evol_func, dt_grid):
+ """Returns a function that builds operations of a single time step."""
+
+ def iteration_func(y, t, dt_step, interval_step):
+ """Performs a single time step advance."""
+ dt = dt_grid[interval_step - 1]
dy = self._step_func(evol_func, t, dt, y)
dy = math_ops.cast(dy, dtype=y.dtype)
- return y + dy
+ return y + dy, t + dt, dt_step + 1, interval_step
+
+ return iteration_func
+
+ def _make_interval_integrator(self, iteration_func, interval_sizes):
+ """Returns a function that builds operations for interval integration."""
- return scan_func
+ def integrate_interval(solution_array, y, t, interval_num):
+ """Integrates y with fixed time step on interval `interval_num`."""
+ y, t, _, _ = control_flow_ops.while_loop(
+ lambda _, __, j, interval_num: j < interval_sizes[interval_num - 1],
+ iteration_func,
+ (y, t, 0, interval_num)
+ )
+ return solution_array.write(interval_num, y), y, t, interval_num + 1
+
+ return integrate_interval
@abc.abstractmethod
def _step_func(self, evol_func, t, dt, y):
@@ -555,6 +618,7 @@ class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)):
class _MidpointFixedGridIntegrator(_FixedGridIntegrator):
+ """Fixed grid integrator implementing midpoint scheme."""
def _step_func(self, evol_func, t, dt, y):
dt_cast = math_ops.cast(dt, y.dtype)
@@ -563,6 +627,7 @@ class _MidpointFixedGridIntegrator(_FixedGridIntegrator):
class _RK4FixedGridIntegrator(_FixedGridIntegrator):
+ """Fixed grid integrator implementing RK4 scheme."""
def _step_func(self, evol_func, t, dt, y):
k1 = evol_func(y, t)
@@ -575,7 +640,7 @@ class _RK4FixedGridIntegrator(_FixedGridIntegrator):
return math_ops.add_n([k1, 2 * k2, 2 * k3, k4]) * (dt_cast / 6)
-def odeint_fixed(func, y0, t, method='rk4', name=None):
+def odeint_fixed(func, y0, t, dt=None, method='rk4', name=None):
"""ODE integration on a fixed grid (with no step size control).
Useful in certain scenarios to avoid the overhead of adaptive step size
@@ -590,6 +655,14 @@ def odeint_fixed(func, y0, t, method='rk4', name=None):
`y`. The initial time point should be the first element of this sequence,
and each time must be larger than the previous time. May have any floating
point dtype.
+ dt: 0-D or 1-D Tensor providing time step suggestion to be used on time
+ integration intervals in `t`. 1-D Tensor should provide values
+ for all intervals, must have 1 less element than that of `t`.
+ If given a 0-D Tensor, the value is interpreted as time step suggestion
+ same for all intervals. If passed None, then time step is set to be the
+ t[1:] - t[:-1]. Defaults to None. The actual step size is obtained by
+ insuring an integer number of steps per interval, potentially reducing the
+ time step.
method: One of 'midpoint' or 'rk4'.
name: Optional name for the resulting operation.
@@ -602,16 +675,29 @@ def odeint_fixed(func, y0, t, method='rk4', name=None):
Raises:
ValueError: Upon caller errors.
"""
- with ops.name_scope(name, 'odeint_fixed', [y0, t]):
+ with ops.name_scope(name, 'odeint_fixed', [y0, t, dt]):
t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
y0 = ops.convert_to_tensor(y0, name='y0')
- _check_input_types(t, y0)
+
+ intervals = t[1:] - t[:-1]
+ if dt is None:
+ dt = intervals
+ dt = ops.convert_to_tensor(dt, preferred_dtype=dtypes.float64, name='dt')
+
+ steps_on_intervals = math_ops.ceil(intervals / dt)
+ dt = intervals / steps_on_intervals
+ steps_on_intervals = math_ops.cast(steps_on_intervals, dtype=dtypes.int32)
+
+ _check_input_types(y0, t, dt)
+ _check_input_sizes(t, dt)
with _assert_increasing(t):
with ops.name_scope(method):
if method == 'midpoint':
- return _MidpointFixedGridIntegrator().integrate(func, y0, t)
+ return _MidpointFixedGridIntegrator().integrate(func, y0, t, dt,
+ steps_on_intervals)
elif method == 'rk4':
- return _RK4FixedGridIntegrator().integrate(func, y0, t)
+ return _RK4FixedGridIntegrator().integrate(func, y0, t, dt,
+ steps_on_intervals)
else:
raise ValueError('method not supported: {!s}'.format(method))
diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py
index 3ec01212d2..c7b4e2faa8 100644
--- a/tensorflow/contrib/integrate/python/ops/odes_test.py
+++ b/tensorflow/contrib/integrate/python/ops/odes_test.py
@@ -242,40 +242,56 @@ class InterpolationTest(test.TestCase):
class OdeIntFixedTest(test.TestCase):
- def _test_integrate_sine(self, method):
+ def _test_integrate_sine(self, method, t, dt=None):
def evol_func(y, t):
del t
return array_ops.stack([y[1], -y[0]])
y0 = [0., 1.]
- time_grid = np.linspace(0., 10., 200)
- y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method)
+ y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
with self.test_session() as sess:
y_grid_array = sess.run(y_grid)
np.testing.assert_allclose(
- y_grid_array[:, 0], np.sin(time_grid), rtol=1e-2, atol=1e-2)
+ y_grid_array[:, 0], np.sin(t), rtol=1e-2, atol=1e-2)
- def _test_integrate_gaussian(self, method):
+ def _test_integrate_gaussian(self, method, t, dt=None):
def evol_func(y, t):
return -math_ops.cast(t, dtype=y.dtype) * y[0]
y0 = [1.]
- time_grid = np.linspace(0., 2., 100)
- y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method)
+ y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
with self.test_session() as sess:
y_grid_array = sess.run(y_grid)
np.testing.assert_allclose(
- y_grid_array[:, 0], np.exp(-time_grid**2 / 2), rtol=1e-2, atol=1e-2)
+ y_grid_array[:, 0], np.exp(-t**2 / 2), rtol=1e-2, atol=1e-2)
+
+ def _test_integrate_sine_all(self, method):
+ uniform_time_grid = np.linspace(0., 10., 200)
+ non_uniform_time_grid = np.asarray([0.0, 0.4, 4.7, 5.2, 7.0])
+ uniform_dt = 0.02
+ non_uniform_dt = np.asarray([0.01, 0.001, 0.05, 0.03])
+ self._test_integrate_sine(method, uniform_time_grid)
+ self._test_integrate_sine(method, non_uniform_time_grid, uniform_dt)
+ self._test_integrate_sine(method, non_uniform_time_grid, non_uniform_dt)
+
+ def _test_integrate_gaussian_all(self, method):
+ uniform_time_grid = np.linspace(0., 2., 100)
+ non_uniform_time_grid = np.asarray([0.0, 0.1, 0.7, 1.2, 2.0])
+ uniform_dt = 0.01
+ non_uniform_dt = np.asarray([0.01, 0.001, 0.1, 0.03])
+ self._test_integrate_gaussian(method, uniform_time_grid)
+ self._test_integrate_gaussian(method, non_uniform_time_grid, uniform_dt)
+ self._test_integrate_gaussian(method, non_uniform_time_grid, non_uniform_dt)
def _test_everything(self, method):
- self._test_integrate_sine(method)
- self._test_integrate_gaussian(method)
+ self._test_integrate_sine_all(method)
+ self._test_integrate_gaussian_all(method)
def test_midpoint(self):
self._test_everything('midpoint')
@@ -283,6 +299,21 @@ class OdeIntFixedTest(test.TestCase):
def test_rk4(self):
self._test_everything('rk4')
+ def test_dt_size_exceptions(self):
+ times = np.linspace(0., 2., 100)
+ dt = np.ones(99) * 0.01
+ dt_wrong_length = np.asarray([0.01, 0.001, 0.1, 0.03])
+ dt_wrong_dim = np.expand_dims(np.linspace(0., 2., 99), axis=0)
+ times_wrong_dim = np.expand_dims(np.linspace(0., 2., 100), axis=0)
+ with self.assertRaises(ValueError):
+ self._test_integrate_gaussian('midpoint', times, dt_wrong_length)
+
+ with self.assertRaises(ValueError):
+ self._test_integrate_gaussian('midpoint', times, dt_wrong_dim)
+
+ with self.assertRaises(ValueError):
+ self._test_integrate_gaussian('midpoint', times_wrong_dim, dt)
+
if __name__ == '__main__':
test.main()