aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/integrate/python/ops/odes.py
blob: 61f78febfc07bb4e677259366a81c16b2b585244 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ODE solvers for TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections

import six

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops

# Coefficients describing an explicit Runge-Kutta scheme, as consumed by
# `_runge_kutta_step` and `_interp_fit_rk` below. With `k` the list of
# derivative evaluations collected during a step:
#   alpha:   fraction of `dt` added to `t0` to get the time of each stage after
#            the first.
#   beta:    for each of those stages, the weights applied to the entries of
#            `k` computed so far to build the state passed to `func`.
#   c_sol:   weights of `k` for the solution at the end of the step.
#   c_mid:   weights of `k` for the solution at the midpoint of the step.
#   c_error: weights of `k` for the error estimate at the end of the step.
# In every case the weighted sum is additionally scaled by `dt`.
_ButcherTableau = collections.namedtuple('_ButcherTableau',
                                         'alpha beta c_sol c_mid c_error')

# Parameters from Shampine (1986), section 4.
# The fractions rely on true division (`from __future__ import division` at
# the top of this module).
_DORMAND_PRINCE_TABLEAU = _ButcherTableau(
    alpha=[1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.],
    beta=[
        [1 / 5],
        [3 / 40, 9 / 40],
        [44 / 45, -56 / 15, 32 / 9],
        [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729],
        [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656],
        # The last row repeats the first six entries of `c_sol`.
        [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84],
    ],
    # One entry per element of `k` (seven in total); the final weight is zero.
    c_sol=[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0],
    c_mid=[
        6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2,
        -2691868925 / 45128329728 / 2, 187940372067 / 1594534317056 / 2,
        -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2
    ],
    # Each entry is written as a second set of weights minus the matching
    # `c_sol` entry (the last `c_sol` entry is zero, hence the bare 1 / 60).
    c_error=[
        1951 / 21600 - 35 / 384,
        0,
        22642 / 50085 - 500 / 1113,
        451 / 720 - 125 / 192,
        -12231 / 42400 - -2187 / 6784,
        649 / 6300 - 11 / 84,
        1 / 60,
    ],)


def _possibly_nonzero(x):
  """Returns a falsy value only when `x` is a non-Tensor equal to zero."""
  if isinstance(x, ops.Tensor):
    # The value of a Tensor is not known while the graph is being built, so it
    # always has to be treated as potentially nonzero.
    return True
  return x != 0


def _scaled_dot_product(scale, xs, ys, name=None):
  """Calculate a scaled, vector inner product between lists of Tensors.

  Computes the sum of `(scale * x) * y` over the pairs `zip(xs, ys)`, leaving
  out every term in which one of the factors is a non-Tensor value equal to
  zero.

  Args:
    scale: factor multiplied into every term.
    xs: sequence of factors (Tensors or plain numbers).
    ys: sequence of factors paired element-wise with `xs`.
    name: optional name for the operation.

  Returns:
    The sum of the retained `(scale * x) * y` terms.
  """
  with ops.name_scope(name, 'scaled_dot_product', [scale, xs, ys]) as scope:
    pairs = list(zip(xs, ys))
    # Some of the parameters in our Butcher tableau include zeros. A term is
    # statically zero as soon as *either* factor is, so a term is only worth
    # computing when *both* factors are possibly nonzero.
    nonzero_pairs = [(x, y) for x, y in pairs
                     if _possibly_nonzero(x) and _possibly_nonzero(y)]
    if not nonzero_pairs:
      # add_n needs at least one input. If every term is statically zero, keep
      # them all so that the result is still a zero of the right shape/dtype.
      nonzero_pairs = pairs
    return math_ops.add_n(
        [(scale * x) * y for x, y in nonzero_pairs], name=scope)


def _dot_product(xs, ys, name=None):
  """Sums the pairwise products of two parallel sequences of Tensors."""
  with ops.name_scope(name, 'dot_product', [xs, ys]) as scope:
    products = []
    for x, y in zip(xs, ys):
      products.append(x * y)
    return math_ops.add_n(products, name=scope)


def _runge_kutta_step(func,
                      y0,
                      f0,
                      t0,
                      dt,
                      tableau=_DORMAND_PRINCE_TABLEAU,
                      name=None):
  """Take an arbitrary Runge-Kutta step and estimate error.

  Args:
    func: Function to evaluate like `func(y, t)` to compute the time derivative
      of `y`.
    y0: Tensor initial value for the state.
    f0: Tensor initial value for the derivative, computed from `func(y0, t0)`.
    t0: float64 scalar Tensor giving the initial time.
    dt: float64 scalar Tensor giving the size of the desired time step.
    tableau: optional _ButcherTableau describing how to take the Runge-Kutta
      step.
    name: optional name for the operation.

  Returns:
    Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
    the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`,
    estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for
    calculating these terms.
  """
  with ops.name_scope(name, 'runge_kutta_step', [y0, f0, t0, dt]) as scope:
    y0 = ops.convert_to_tensor(y0, name='y0')
    f0 = ops.convert_to_tensor(f0, name='f0')
    t0 = ops.convert_to_tensor(t0, name='t0')
    dt = ops.convert_to_tensor(dt, name='dt')
    # Times stay in the dtype of `t0`/`dt`; anything combined with the state
    # uses the step size cast to the dtype of `y0`.
    dt_cast = math_ops.cast(dt, y0.dtype)

    # `k` accumulates one derivative evaluation per stage, starting with the
    # derivative at the beginning of the step.
    k = [f0]
    for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
      ti = t0 + alpha_i * dt
      yi = y0 + _scaled_dot_product(dt_cast, beta_i, k)
      k.append(func(yi, ti))

    # `c_sol` has one weight per entry of `k`, i.e. one more than the last row
    # of `beta`. When that extra weight is zero and the remaining weights equal
    # the last row of `beta`, the state `yi` built for the final stage already
    # is the solution, so it does not have to be recomputed. The comparison
    # must therefore leave out the last element of `c_sol`; comparing the full
    # lists can never succeed because their lengths differ.
    if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]):
      # This property (true for Dormand-Prince) lets us save a few FLOPs.
      yi = y0 + _scaled_dot_product(dt_cast, tableau.c_sol, k)

    y1 = array_ops.identity(yi, name='%s/y1' % scope)
    f1 = array_ops.identity(k[-1], name='%s/f1' % scope)
    y1_error = _scaled_dot_product(
        dt_cast, tableau.c_error, k, name='%s/y1_error' % scope)
    return (y1, f1, y1_error, k)


def _interp_fit(y0, y1, y_mid, f0, f1, dt):
  """Fit coefficients for 4th order polynomial interpolation.

  The polynomial matches the function value at the start, middle and end of
  the interval and the derivative at the start and end of the interval.

  Args:
    y0: function value at the start of the interval.
    y1: function value at the end of the interval.
    y_mid: function value at the mid-point of the interval.
    f0: derivative value at the start of the interval.
    f1: derivative value at the end of the interval.
    dt: width of the interval.

  Returns:
    List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial
    `p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x`
    between 0 (start of interval) and 1 (end of interval).
  """
  # The weights below come from solving the five matching conditions with
  # sympy:
  #
  # a, b, c, d, e = sympy.symbols('a b c d e')
  # x, dt, y0, y1, y_mid, f0, f1 = sympy.symbols('x dt y0 y1 y_mid f0 f1')
  # p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e
  # sympy.solve([p.subs(x, 0) - y0,
  #              p.subs(x, 1 / 2) - y_mid,
  #              p.subs(x, 1) - y1,
  #              (p.diff(x) / dt).subs(x, 0) - f0,
  #              (p.diff(x) / dt).subs(x, 1) - f1],
  #             [a, b, c, d, e])
  # {a: -2.0*dt*f0 + 2.0*dt*f1 - 8.0*y0 - 8.0*y1 + 16.0*y_mid,
  #  b: 5.0*dt*f0 - 3.0*dt*f1 + 18.0*y0 + 14.0*y1 - 32.0*y_mid,
  #  c: -4.0*dt*f0 + dt*f1 - 11.0*y0 - 5.0*y1 + 16.0*y_mid,
  #  d: dt*f0,
  #  e: y0}
  samples = [f0, f1, y0, y1, y_mid]
  quartic = _dot_product([-2 * dt, 2 * dt, -8, -8, 16], samples)
  cubic = _dot_product([5 * dt, -3 * dt, 18, 14, -32], samples)
  quadratic = _dot_product([-4 * dt, dt, -11, -5, 16], samples)
  linear = dt * f0
  return [quartic, cubic, quadratic, linear, y0]


def _interp_fit_rk(y0, y1, k, dt, tableau=_DORMAND_PRINCE_TABLEAU):
  """Fit an interpolating polynomial to the results of a Runge-Kutta step.

  Args:
    y0: state at the start of the step.
    y1: state at the end of the step.
    k: list of derivative evaluations produced by the step; the first entry is
      used as the derivative at the start and the last as the derivative at
      the end.
    dt: size of the step.
    tableau: optional _ButcherTableau whose `c_mid` weights give the state at
      the middle of the step.

  Returns:
    The coefficient list produced by `_interp_fit`.
  """
  with ops.name_scope('interp_fit_rk'):
    dt_cast = math_ops.cast(dt, y0.dtype)
    midpoint = y0 + _scaled_dot_product(dt_cast, tableau.c_mid, k)
    return _interp_fit(y0, y1, midpoint, k[0], k[-1], dt_cast)


def _interp_evaluate(coefficients, t0, t1, t):
  """Evaluate polynomial interpolation at the given time point.

  Args:
    coefficients: list of Tensor coefficients as created by `_interp_fit`,
      ordered from the highest power down to the constant term.
    t0: scalar float64 Tensor giving the start of the interval.
    t1: scalar float64 Tensor giving the end of the interval.
    t: scalar float64 Tensor giving the desired interpolation point.

  Returns:
    Polynomial interpolation of the coefficients at time `t`.
  """
  with ops.name_scope('interp_evaluate'):
    t0 = ops.convert_to_tensor(t0)
    t1 = ops.convert_to_tensor(t1)
    t = ops.convert_to_tensor(t)

    dtype = coefficients[0].dtype

    within_interval = (t0 <= t) & (t <= t1)
    assert_op = control_flow_ops.Assert(
        within_interval,
        ['invalid interpolation, fails `t0 <= t <= t1`:', t0, t, t1])
    with ops.control_dependencies([assert_op]):
      # Position of `t` inside the interval, rescaled to [0, 1].
      x = math_ops.cast((t - t0) / (t1 - t0), dtype)

    # Ascending powers of x: [1, x, x**2, ...], one per coefficient.
    powers = [constant_op.constant(1, dtype), x]
    while len(powers) < len(coefficients):
      powers.append(powers[-1] * x)

    # Coefficients are stored highest power first, hence the reversal.
    return _dot_product(coefficients, reversed(powers))


def _optimal_step_size(last_step,
                       error_ratio,
                       safety=0.9,
                       ifactor=10.0,
                       dfactor=0.2,
                       order=5,
                       name=None):
  """Calculate the optimal size for the next Runge-Kutta step.

  Args:
    last_step: Tensor giving the size of the step that was just attempted.
    error_ratio: ratio of the estimated error to the tolerance for that step.
    safety: safety factor dividing the error-based scaling.
    ifactor: largest factor by which the step may grow.
    dfactor: smallest factor to which the step may shrink.
    order: the error ratio is raised to the power `1 / order`.
    name: optional name for the operation.

  Returns:
    `last_step` divided by a factor clipped to `[1 / ifactor, 1 / dfactor]`.
  """
  with ops.name_scope(name, 'optimal_step_size', [last_step,
                                                  error_ratio]) as scope:
    step_dtype = last_step.dtype
    error_ratio = math_ops.cast(error_ratio, step_dtype)
    exponent = math_ops.cast(1 / order, step_dtype)
    # The step is *divided* by the factor below. That looks more complex than
    # necessary, but it keeps error_ratio in the numerator, so an error ratio
    # of zero cannot cause a division by zero.
    smallest_divisor = 1 / ifactor
    error_divisor = error_ratio**exponent / safety
    largest_divisor = 1 / dfactor
    factor = math_ops.maximum(
        smallest_divisor, math_ops.minimum(error_divisor, largest_divisor))
    return math_ops.div(last_step, factor, name=scope)


def _abs_square(x):
  """Returns the element-wise squared magnitude of a real or complex `x`."""
  if not x.dtype.is_complex:
    return math_ops.square(x)
  real_squared = math_ops.square(math_ops.real(x))
  imag_squared = math_ops.square(math_ops.imag(x))
  return real_squared + imag_squared


def _ta_append(tensor_array, value):
  """Writes `value` at index `tensor_array.size()` and returns the result."""
  next_index = tensor_array.size()
  return tensor_array.write(next_index, value)


class _RungeKuttaState(
    collections.namedtuple('_RungeKuttaState',
                           'y1, f1, t0, t1, dt, interp_coeff')):
  """Saved state of the Runge Kutta solver.

  Attributes:
    y1: Tensor giving the function value at the end of the last time step.
    f1: Tensor giving derivative at the end of the last time step.
    t0: scalar float64 Tensor giving start of the last time step.
    t1: scalar float64 Tensor giving end of the last time step.
    dt: scalar float64 Tensor giving the size for the next time step.
    interp_coeff: list of Tensors giving coefficients for polynomial
      interpolation between `t0` and `t1`.
  """


class _History(
    collections.namedtuple('_History', 'integrate_points, error_ratio')):
  """Saved integration history for use in `info_dict`.

  Both arrays receive one entry per attempted integration step.

  Attributes:
    integrate_points: tf.TensorArray storing integrating time points.
    error_ratio: tf.TensorArray storing computed error ratios at each
      integration step.
  """


def _assert_increasing(t):
  """Returns a control-dependency scope asserting `t` strictly increases.

  Ops created inside the returned context manager depend on an assertion that
  every element of `t` is larger than the element before it.
  """
  later, earlier = t[1:], t[:-1]
  strictly_increasing = math_ops.reduce_all(later > earlier)
  check = control_flow_ops.Assert(strictly_increasing,
                                  ['`t` must be monotonic increasing'])
  return ops.control_dependencies([check])


def _check_input_types(y0, t, dt=None):
  """Validates the dtypes of the solver inputs.

  Args:
    y0: Tensor whose dtype must be floating point or complex.
    t: Tensor whose dtype must be floating point.
    dt: optional Tensor whose dtype must be floating point; not checked when
      None.

  Raises:
    TypeError: if any of the dtypes is not of an accepted kind.
  """
  state_dtype = y0.dtype
  if not state_dtype.is_floating and not state_dtype.is_complex:
    raise TypeError('`y0` must have a floating point or complex floating '
                    'point dtype')

  if not t.dtype.is_floating:
    raise TypeError('`t` must have a floating point dtype')

  if dt is None:
    return
  if not dt.dtype.is_floating:
    raise TypeError('`dt` must have a floating point dtype')


def _check_input_sizes(t, dt):
  """Validates that `t` and `dt` are 1-D with lengths N and N-1.

  Args:
    t: Tensor of time points.
    dt: Tensor of time steps, one per interval between consecutive entries of
      `t`.

  Raises:
    ValueError: if `t` or `dt` is not 1-D, or if `t` does not have exactly one
      more element than `dt`.
  """
  # Anything other than rank 1 is rejected here, so that a scalar produces the
  # intended ValueError instead of failing on the `[0]` lookups below.
  if len(t.get_shape().as_list()) != 1:
    raise ValueError('t must be a 1D tensor')

  if len(dt.get_shape().as_list()) != 1:
    raise ValueError('dt must be a 1D tensor')

  if t.get_shape()[0] != dt.get_shape()[0] + 1:
    raise ValueError('t and dt have incompatible lengths, must be N and N-1')


def _dopri5(func,
            y0,
            t,
            rtol,
            atol,
            full_output=False,
            first_step=None,
            safety=0.9,
            ifactor=10.0,
            dfactor=0.2,
            max_num_steps=1000,
            name=None):
  """Solve an ODE for `odeint` using method='dopri5'.

  The solver takes adaptive Runge-Kutta steps until the most recent step
  covers the next requested time point, then evaluates the interpolating
  polynomial fitted to that step at the requested time.

  Args:
    func: Function evaluated like `func(y, t)` giving the time derivative of
      `y`.
    y0: Tensor giving the state at `t[0]`.
    t: Tensor of increasing time points at which the solution is recorded.
    rtol: Tensor giving the relative error tolerance.
    atol: Tensor giving the absolute error tolerance.
    full_output: if True, also return a dict describing the integration.
    first_step: size of the first attempted step; 1.0 is used when None.
    safety: safety factor passed to `_optimal_step_size`.
    ifactor: maximum step growth factor passed to `_optimal_step_size`.
    dfactor: maximum step shrink factor passed to `_optimal_step_size`.
    max_num_steps: maximum number of attempted steps between two consecutive
      entries of `t`.
    name: optional name for the operation.

  Returns:
    The stacked solution `y` with one entry per element of `t`, or the tuple
    `(y, info_dict)` when `full_output` is True.
  """

  if first_step is None:
    # at some point, we might want to switch to picking the step size
    # automatically
    first_step = 1.0

  with ops.name_scope(name, 'dopri5', [
      y0, t, rtol, atol, safety, ifactor, dfactor, max_num_steps
  ]) as scope:

    # Step-size controls share the dtype of the time points.
    first_step = ops.convert_to_tensor(
        first_step, dtype=t.dtype, name='first_step')
    safety = ops.convert_to_tensor(safety, dtype=t.dtype, name='safety')
    ifactor = ops.convert_to_tensor(ifactor, dtype=t.dtype, name='ifactor')
    dfactor = ops.convert_to_tensor(dfactor, dtype=t.dtype, name='dfactor')
    max_num_steps = ops.convert_to_tensor(
        max_num_steps, dtype=dtypes.int32, name='max_num_steps')

    def adaptive_runge_kutta_step(rk_state, history, n_steps):
      """Take an adaptive Runge-Kutta step to integrate the ODE."""
      # The end point of the previous step (y1, f1, t1) is the starting point
      # of this one; the previous start time is not needed.
      y0, f0, _, t0, dt, interp_coeff = rk_state
      with ops.name_scope('assertions'):
        # Fails once dt is too small to change t0 when added to it.
        check_underflow = control_flow_ops.Assert(t0 + dt > t0,
                                                  ['underflow in dt', dt])
        check_max_num_steps = control_flow_ops.Assert(
            n_steps < max_num_steps, ['max_num_steps exceeded'])
        check_numerics = control_flow_ops.Assert(
            math_ops.reduce_all(math_ops.is_finite(abs(y0))),
            ['non-finite values in state `y`', y0])
      with ops.control_dependencies(
          [check_underflow, check_max_num_steps, check_numerics]):
        y1, f1, y1_error, k = _runge_kutta_step(func, y0, f0, t0, dt)

      with ops.name_scope('error_ratio'):
        # We use the same approach as the dopri5 fortran code.
        error_tol = atol + rtol * math_ops.maximum(abs(y0), abs(y1))
        tensor_error_ratio = _abs_square(y1_error) / _abs_square(error_tol)
        # Could also use reduce_maximum here.
        # Root-mean-square of the element-wise error-to-tolerance ratios.
        error_ratio = math_ops.sqrt(math_ops.reduce_mean(tensor_error_ratio))
        accept_step = error_ratio <= 1

      with ops.name_scope('update/rk_state'):
        # If we don't accept the step, the _RungeKuttaState will be useless
        # (covering a time-interval of size 0), but that's OK, because in such
        # cases we always immediately take another Runge-Kutta step.
        y_next = control_flow_ops.cond(accept_step, lambda: y1, lambda: y0)
        f_next = control_flow_ops.cond(accept_step, lambda: f1, lambda: f0)
        t_next = control_flow_ops.cond(accept_step, lambda: t0 + dt, lambda: t0)
        interp_coeff = control_flow_ops.cond(
            accept_step, lambda: _interp_fit_rk(y0, y1, k, dt),
            lambda: interp_coeff)
        # The next step size is recomputed whether or not the step was
        # accepted.
        dt_next = _optimal_step_size(dt, error_ratio, safety, ifactor, dfactor)
        rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next,
                                    interp_coeff)

      with ops.name_scope('update/history'):
        # Every attempted step is recorded, including rejected ones.
        history = _History(
            _ta_append(history.integrate_points, t0 + dt),
            _ta_append(history.error_ratio, error_ratio))
      return rk_state, history, n_steps + 1

    def interpolate(solution, history, rk_state, i):
      """Interpolate through the next time point, integrating as necessary."""
      with ops.name_scope('interpolate'):
        # Keep stepping until the last step reaches t[i]. The step counter
        # starts from 0 for every output point, so `max_num_steps` limits the
        # number of steps per output point rather than in total.
        rk_state, history, _ = control_flow_ops.while_loop(
            lambda rk_state, *_: t[i] > rk_state.t1,
            adaptive_runge_kutta_step, (rk_state, history, 0),
            name='integrate_loop')
        y = _interp_evaluate(rk_state.interp_coeff, rk_state.t0, rk_state.t1,
                             t[i])
        solution = solution.write(i, y)
        return solution, history, rk_state, i + 1

    with _assert_increasing(t):
      num_times = array_ops.size(t)

    # The first output is the initial condition itself.
    solution = tensor_array_ops.TensorArray(
        y0.dtype, size=num_times).write(0, y0)
    history = _History(
        integrate_points=tensor_array_ops.TensorArray(
            t.dtype, size=0, dynamic_size=True),
        error_ratio=tensor_array_ops.TensorArray(
            rtol.dtype, size=0, dynamic_size=True))
    # Initial state: a zero-width "step" from t[0] to t[0], with placeholder
    # interpolation coefficients.
    rk_state = _RungeKuttaState(
        y0, func(y0, t[0]), t[0], t[0], first_step, interp_coeff=[y0] * 5)

    # Output index 0 is already written, so the loop starts at index 1.
    solution, history, _, _ = control_flow_ops.while_loop(
        lambda _, __, ___, i: i < num_times,
        interpolate, (solution, history, rk_state, 1),
        name='interpolate_loop')

    y = solution.stack(name=scope)
    y.set_shape(t.get_shape().concatenate(y0.get_shape()))
    if not full_output:
      return y
    else:
      integrate_points = history.integrate_points.stack()
      info_dict = {
          # Six evaluations per attempted step (one per stage of the default
          # tableau) plus the evaluation for the initial state above.
          'num_func_evals': 6 * array_ops.size(integrate_points) + 1,
          'integrate_points': integrate_points,
          'error_ratio': history.error_ratio.stack()
      }
      return (y, info_dict)


def odeint(func,
           y0,
           t,
           rtol=1e-6,
           atol=1e-12,
           method=None,
           options=None,
           full_output=False,
           name=None):
  """Integrate a system of ordinary differential equations.

  Solves the initial value problem for a non-stiff system of first order ODEs:

    ```
    dy/dt = func(y, t), y(t[0]) = y0
    ```

  where y is a Tensor of any shape.

  For example:

    ```
    # solve `dy/dt = -y`, corresponding to exponential decay
    tf.contrib.integrate.odeint(lambda y, _: -y, 1.0, [0, 1, 2])
    => [1, exp(-1), exp(-2)]
    ```

  Output dtypes and numerical precision are based on the dtypes of the inputs
  `y0` and `t`.

  Currently, implements 5th order Runge-Kutta with adaptive step size control
  and dense output, using the Dormand-Prince method. Similar to the 'dopri5'
  method of `scipy.integrate.ode` and MATLAB's `ode45`.

  Based on: Shampine, Lawrence F. (1986), "Some Practical Runge-Kutta Formulas",
  Mathematics of Computation, American Mathematical Society, 46 (173): 135-150,
  doi:10.2307/2008219

  Args:
    func: Function that maps a Tensor holding the state `y` and a scalar Tensor
      `t` into a Tensor of state derivatives with respect to time.
    y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
      have any floating point or complex dtype.
    t: 1-D Tensor holding a sequence of time points for which to solve for
      `y`. The initial time point should be the first element of this sequence,
      and each time must be larger than the previous time. May have any floating
      point dtype. If not provided as a Tensor, converted to a Tensor with
      float64 dtype.
    rtol: optional Tensor specifying an upper bound on relative error, per
      element of `y`. Converted to the dtype of `abs(y0)`.
    atol: optional Tensor specifying an upper bound on absolute error, per
      element of `y`. Converted to the dtype of `abs(y0)`.
    method: optional string indicating the integration method to use. Currently,
      the only valid option is `'dopri5'`.
    options: optional dict of configuring options for the indicated integration
      method. Can only be provided if a `method` is explicitly set. For
      `'dopri5'`, valid options include:
      * first_step: an initial guess for the size of the first integration
        (current default: 1.0, but may later be changed to use heuristics based
        on the gradient).
      * safety: safety factor for adaptive step control, generally a constant
        in the range 0.8-1 (default: 0.9).
      * ifactor: maximum factor by which the adaptive step may be increased
        (default: 10.0).
      * dfactor: maximum factor by which the adaptive step may be decreased
        (default: 0.2).
      * max_num_steps: integer maximum number of integrate steps between time
        points in `t` (default: 1000).
    full_output: optional boolean. If True, `odeint` returns a tuple
      `(y, info_dict)` describing the integration process.
    name: Optional name for this operation.

  Returns:
    y: (N+1)-D tensor, where the first dimension corresponds to different
      time points. Contains the solved value of y for each desired time point in
      `t`, with the initial value `y0` being the first element along the first
      dimension.
    info_dict: only if `full_output == True`. A dict with the following values:
      * num_func_evals: integer Tensor counting the number of function
        evaluations.
      * integrate_points: 1D Tensor, with the dtype of `t`, holding the upper
        bound of each integration time step.
      * error_ratio: 1D float Tensor with the estimated ratio of the integration
        error to the error tolerance at each integration step. A ratio greater
        than 1 corresponds to rejected steps.

  Raises:
    ValueError: if an invalid `method` is provided, or if `options` is supplied
      without `method`.
    TypeError: if `t` or `y0` has an invalid dtype.
  """
  if method is not None and method != 'dopri5':
    raise ValueError('invalid method: %r' % method)

  if options is None:
    options = {}
  elif method is None:
    raise ValueError('cannot supply `options` without specifying `method`')

  with ops.name_scope(name, 'odeint', [y0, t, rtol, atol]) as scope:
    # TODO(shoyer): use nest.flatten (like tf.while_loop) to allow `y0` to be an
    # arbitrarily nested tuple. This will help performance and usability by
    # avoiding the need to pack/unpack in user functions.
    y0 = ops.convert_to_tensor(y0, name='y0')
    t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
    _check_input_types(y0, t)

    # Tolerances are compared against magnitudes of the state, so they use the
    # dtype of `abs(y0)` rather than the dtype of `y0` itself.
    error_dtype = abs(y0).dtype
    rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol')
    atol = ops.convert_to_tensor(atol, dtype=error_dtype, name='atol')

    # The entries of `options` are forwarded as keyword arguments.
    return _dopri5(
        func,
        y0,
        t,
        rtol=rtol,
        atol=atol,
        full_output=full_output,
        name=scope,
        **options)


class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)):
  """Abstract base for ODE integrators that advance with prescribed steps.

  Subclasses provide the integration scheme by implementing `_step_func`.
  """

  def integrate(self, evol_func, y0, time_grid, dt_grid, steps_on_intervals):
    """Returns integrated values of differential equation on the `time grid`.

    Numerically integrates differential equation defined via time derivative
    evaluator `evol_func` using fixed time steps specified in dt_grid.

    Args:
      evol_func: Callable, evaluates time derivative of y at a given time.
      y0: N-D Tensor holds initial values of the solution.
      time_grid: 1-D Tensor holding the time points at which the solution
        will be recorded, must have a floating dtype.
      dt_grid: 1-D Tensor holds fixed time steps to be used on time_grid
        intervals. Must be a floating dtype and have one less element than that
        of the time_grid.
      steps_on_intervals: 1-D Tensor of integer dtype, must have the same size
        as dt_grid. Specifies number of steps needed for every interval. Assumes
        steps_on_intervals * dt_grid == time intervals.

    Returns:
      (N+1)-D tensor, where the first dimension corresponds to different
      time points. Contains the solved value of y for each time point in
      `time_grid`, with the initial value `y0` being the first element along
      the first dimension.
    """
    single_step = self._make_iteration_func(evol_func, dt_grid)
    advance_interval = self._make_interval_integrator(single_step,
                                                      steps_on_intervals)

    num_times = array_ops.size(time_grid)
    start_time = time_grid[0]
    # The first recorded value is the initial condition.
    recorded = tensor_array_ops.TensorArray(y0.dtype, num_times).write(0, y0)

    def has_more_intervals(unused_recorded, unused_y, unused_t, interval_num):
      return interval_num < num_times

    # Interval numbering starts at 1: interval `i` ends at `time_grid[i]`.
    recorded, _, _, _ = control_flow_ops.while_loop(
        has_more_intervals,
        advance_interval,
        (recorded, y0, start_time, 1)
    )
    result = recorded.stack()
    result.set_shape(time_grid.get_shape().concatenate(y0.get_shape()))
    return result

  def _make_iteration_func(self, evol_func, dt_grid):
    """Returns a function that builds operations of a single time step."""

    def advance(y, t, steps_taken, interval_num):
      """Advances `y` and `t` by one step of the current interval."""
      dt = dt_grid[interval_num - 1]
      increment = math_ops.cast(
          self._step_func(evol_func, t, dt, y), dtype=y.dtype)
      return y + increment, t + dt, steps_taken + 1, interval_num

    return advance

  def _make_interval_integrator(self, iteration_func, interval_sizes):
    """Returns a function that builds operations for interval integration."""

    def integrate_interval(solution_array, y, t, interval_num):
      """Integrates y with fixed time step on interval `interval_num`."""

      def steps_remaining(unused_y, unused_t, steps_taken, current_interval):
        return steps_taken < interval_sizes[current_interval - 1]

      y_end, t_end, _, _ = control_flow_ops.while_loop(
          steps_remaining,
          iteration_func,
          (y, t, 0, interval_num)
      )
      updated_array = solution_array.write(interval_num, y_end)
      return updated_array, y_end, t_end, interval_num + 1

    return integrate_interval

  @abc.abstractmethod
  def _step_func(self, evol_func, t, dt, y):
    """Returns the increment of `y` over one step of size `dt` from `t`."""


class _MidpointFixedGridIntegrator(_FixedGridIntegrator):
  """Fixed grid integrator using the midpoint scheme."""

  def _step_func(self, evol_func, t, dt, y):
    """Returns the increment of `y` for one midpoint step of size `dt`."""
    dt_cast = math_ops.cast(dt, y.dtype)
    # y_{n+1} = y_n + h * f(y_n + f(y_n, t_n) * h / 2, t_n + h / 2)
    slope_at_start = evol_func(y, t)
    y_half = y + slope_at_start * dt_cast / 2
    t_half = t + dt / 2
    return dt_cast * evol_func(y_half, t_half)


class _RK4FixedGridIntegrator(_FixedGridIntegrator):
  """Fixed grid integrator using the classical 4-stage Runge-Kutta scheme."""

  def _step_func(self, evol_func, t, dt, y):
    """Returns the increment of `y` for one RK4 step of size `dt`."""
    slope_start = evol_func(y, t)
    t_mid = t + dt / 2
    # Times keep their own dtype; the step is cast before it touches the state.
    dt_cast = math_ops.cast(dt, y.dtype)

    slope_mid_a = evol_func(y + dt_cast * slope_start / 2, t_mid)
    slope_mid_b = evol_func(y + dt_cast * slope_mid_a / 2, t_mid)
    slope_end = evol_func(y + dt_cast * slope_mid_b, t + dt)
    # The two midpoint slopes carry twice the weight of the end-point slopes.
    weighted_slopes = math_ops.add_n(
        [slope_start, 2 * slope_mid_a, 2 * slope_mid_b, slope_end])
    return weighted_slopes * (dt_cast / 6)


def odeint_fixed(func, y0, t, dt=None, method='rk4', name=None):
  """ODE integration on a fixed grid (with no step size control).

  Useful in certain scenarios to avoid the overhead of adaptive step size
  control, e.g. when differentiation of the integration result is desired and/or
  the time grid is known a priori to be sufficient.

  Args:
    func: Function that maps a Tensor holding the state `y` and a scalar Tensor
      `t` into a Tensor of state derivatives with respect to time.
    y0: N-D Tensor giving starting value of `y` at time point `t[0]`.
    t: 1-D Tensor holding a sequence of time points for which to solve for
      `y`. The initial time point should be the first element of this sequence,
      and each time must be larger than the previous time. May have any floating
      point dtype.
    dt: 0-D or 1-D Tensor providing time step suggestion to be used on time
      integration intervals in `t`. 1-D Tensor should provide values
      for all intervals, must have 1 less element than that of `t`.
      If given a 0-D Tensor, the value is interpreted as time step suggestion
      same for all intervals. If passed None, then time step is set to be the
      t[1:] - t[:-1]. Defaults to None. The actual step size is obtained by
      ensuring an integer number of steps per interval, potentially reducing the
      time step.
    method: One of 'midpoint' or 'rk4'.
    name: Optional name for the resulting operation.

  Returns:
    y: (N+1)-D tensor, where the first dimension corresponds to different
      time points. Contains the solved value of y for each desired time point in
      `t`, with the initial value `y0` being the first element along the first
      dimension.

  Raises:
    ValueError: Upon caller errors.
  """
  with ops.name_scope(name, 'odeint_fixed', [y0, t, dt]):
    t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
    y0 = ops.convert_to_tensor(y0, name='y0')

    intervals = t[1:] - t[:-1]
    # Without a suggestion, each interval is covered by a single step.
    suggested_dt = ops.convert_to_tensor(
        intervals if dt is None else dt,
        preferred_dtype=dtypes.float64,
        name='dt')

    # Round the number of steps up so that a whole number of equal steps spans
    # each interval; the resulting step is never larger than the suggestion.
    step_counts = math_ops.ceil(intervals / suggested_dt)
    dt = intervals / step_counts
    steps_on_intervals = math_ops.cast(step_counts, dtype=dtypes.int32)

    _check_input_types(y0, t, dt)
    _check_input_sizes(t, dt)

    with _assert_increasing(t):
      with ops.name_scope(method):
        if method == 'midpoint':
          integrator = _MidpointFixedGridIntegrator()
        elif method == 'rk4':
          integrator = _RK4FixedGridIntegrator()
        else:
          raise ValueError('method not supported: {!s}'.format(method))
        return integrator.integrate(func, y0, t, dt, steps_on_intervals)