aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/timeseries/python/timeseries/model.py
blob: f2ef8d22114be50a10d3b106be5e144cc70b4bfc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for time series models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections

from tensorflow.contrib import layers

from tensorflow.contrib.timeseries.python.timeseries import math_utils
from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures
from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope

from tensorflow.python.util import nest


# Results of running a model over a batch of data. Fields, in order:
#   loss: The scalar value to be minimized during training.
#   end_state: A nested tuple specifying the model's state after running on the
#     specified data.
#   predictions: A dictionary of predictions, each with shape prefixed by the
#     shape of `prediction_times`.
#   prediction_times: A [batch size x window size] integer Tensor indicating
#     times for which values in `predictions` were computed.
ModelOutputs = collections.namedtuple(  # pylint: disable=invalid-name
    "ModelOutputs",
    ("loss", "end_state", "predictions", "prediction_times"))


class TimeSeriesModel(object):
  """Base class for creating generative time series models."""

  __metaclass__ = abc.ABCMeta

  def __init__(self,
               num_features,
               exogenous_feature_columns=None,
               dtype=dtypes.float32):
    """Records the configuration shared by all generative models.

    Args:
      num_features: Number of features for the time series.
      exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
          objects (for example tf.contrib.layers.embedding_column) describing
          exogenous features: extra information available to the model which
          is not part of the series being predicted. These are passed to
          tf.contrib.layers.input_from_feature_columns. A missing or empty
          value is stored as an empty list.
      dtype: The floating point datatype to use.
    """
    self.num_features = num_features
    self.dtype = dtype
    # Falsy values (None, empty sequences) are normalized to a fresh list.
    self._exogenous_feature_columns = exogenous_feature_columns or []
    # Both of these are overwritten by initialize_graph().
    self._graph_initialized = False
    self._input_statistics = None

  # TODO(allenl): Move more of the generic machinery for generating and
  # predicting into TimeSeriesModel, and possibly share it between generate()
  # and predict()
  def generate(self, number_of_series, series_length,
               model_parameters=None, seed=None):
    """Sample synthetic data from model parameters, with optional substitutions.

    Subclasses which support generation return `number_of_series` possible
    sequences of future values, sampled from the generative model with each
    value conditioned on the previous one. Trained parameters are used for
    sampling, except where explicitly overridden in `model_parameters`.

    This base implementation always raises. For distributions over future
    observations rather than sampled draws, see predict().

    Args:
      number_of_series: Number of time series to create.
      series_length: Length of each time series.
      model_parameters: A dictionary mapping model parameters to values, which
          replace trained parameters when generating data.
      seed: If specified, return deterministic time series according to this
          value.
    Returns:
      A dictionary with keys TrainEvalFeatures.TIMES (mapping to an array with
      shape [number_of_series, series_length]) and TrainEvalFeatures.VALUES
      (mapping to an array with shape [number_of_series, series_length,
      num_features]).
    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    del number_of_series, series_length, model_parameters, seed  # Unused
    raise NotImplementedError("This model does not support generation.")

  def initialize_graph(self, input_statistics=None):
    """Define ops for the model, not depending on any previously defined ops.

    The base implementation only stores `input_statistics` and marks the model
    as initialized, which `_check_graph_initialized` verifies before use.

    Args:
      input_statistics: A math_utils.InputStatistics object containing input
          statistics. If None, data-independent defaults are used, which may
          result in longer or unstable training.
    """
    self._input_statistics = input_statistics
    self._graph_initialized = True

  def _check_graph_initialized(self):
    """Raises a ValueError if initialize_graph() has not been called yet."""
    if self._graph_initialized:
      return
    raise ValueError(
        "TimeSeriesModels require initialize_graph() to be called before "
        "use. This defines variables and ops in the default graph, and "
        "allows Tensor-valued input statistics to be specified.")

  def define_loss(self, features, mode):
    """Default loss definition with state replicated across a batch.

    Every series in a batch is treated as an independent sample which starts
    from the same initial state: the state from get_start_state() is simply
    replicated once per batch element, and the batch is then handed to
    get_batch_loss(). With a batch size of one this amounts to sequential
    operations on a single time series.

    Callers needing more complex processing may use get_start_state() and
    get_batch_loss() directly instead.

    Args:
      features: A dictionary (such as is produced by a chunker) with at minimum
        the following key/value pairs (others corresponding to the
        `exogenous_feature_columns` argument to `__init__` may be included
        representing exogenous regressors):
        TrainEvalFeatures.TIMES: A [batch size x window size] integer Tensor
            with times for each observation. If there is no artificial chunking,
            the window size is simply the length of the time series.
        TrainEvalFeatures.VALUES: A [batch size x window size x num features]
            Tensor with values for each observation.
      mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL). For INFER,
        see predict().
    Returns:
      A ModelOutputs object.
    Raises:
      ValueError: If initialize_graph() has not been called.
    """
    self._check_graph_initialized()
    unbatched_start_state = self.get_start_state()
    # The batch size is read dynamically, so it need not be known statically.
    batch_size = array_ops.shape(features[TrainEvalFeatures.TIMES])[0]
    batched_start_state = math_utils.replicate_state(
        start_state=unbatched_start_state, batch_size=batch_size)
    return self.get_batch_loss(
        features=features, mode=mode, state=batched_start_state)

  # TODO(vitalyk,allenl): Better documentation surrounding options for chunking,
  # references to papers, etc.
  @abc.abstractmethod
  def get_start_state(self):
    """Provides the model state for the start of a time series.

    Implemented by subclasses. The state (for example a mean and a covariance)
    should not have a batch dimension; `define_loss` replicates it across the
    batch. It will often consist of TensorFlow Variables which are learned
    along with the rest of the model parameters.

    Returns:
      A tuple of state for the start of the time series.
    """

  @abc.abstractmethod
  def get_batch_loss(self, features, mode, state):
    """Computes predictions, losses, and end state for a batch of series.

    Implemented by subclasses.

    Args:
      features: A dictionary with times, values, and (optionally) exogenous
          regressors. See `define_loss`.
      mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL, INFER).
      state: Model-dependent state, each with size [batch size x ...]. The
          number and type will typically be fixed by the model (for example a
          mean and variance).
    Returns:
      A ModelOutputs object.
    """

  @abc.abstractmethod
  def predict(self, features):
    """Returns predictions of future observations given an initial state.

    Implemented by subclasses. Predictions are distributions over future
    observations; for sampled draws from the model where each draw is
    conditioned on the previous one, see generate().

    Args:
      features: A dictionary with at minimum the following key/value pairs
        (others corresponding to the `exogenous_feature_columns` argument to
        `__init__` may be included representing exogenous regressors):
        PredictionFeatures.TIMES: A [batch size x window size] Tensor with
          times to make predictions for. Times must be increasing within each
          part of the batch, and must be greater than the last time `state` was
          updated.
        PredictionFeatures.STATE_TUPLE: Model-dependent state, each with size
          [batch size x ...]. The number and type will typically be fixed by the
          model (for example a mean and variance). Typically these will be the
          end state returned by get_batch_loss, predicting beyond that data.
    Returns:
      A dictionary with model-dependent predictions corresponding to the
      requested times. Keys indicate the type of prediction, and values have
      shape [batch size x window size x ...]. For example state space models
      return a "predicted_mean" and "predicted_covariance".
    """

  def _process_exogenous_features(self, times, features):
    """Create a single vector from exogenous features.

    The batch and window dimensions of every feature are merged into one
    leading dimension, the features are embedded with
    `layers.input_from_feature_columns`, and the batch/window dimensions are
    then restored on the embedded result.

    Args:
      times: A [batch size, window size] vector of times for this batch,
          primarily used to check the shape information of exogenous features.
      features: A dictionary of exogenous features corresponding to the columns
          in self._exogenous_feature_columns. Each value should have a shape
          prefixed by [batch size, window size].
    Returns:
      A Tensor with shape [batch size, window size, exogenous dimension] and
      dtype `self.dtype`, where the size of the exogenous dimension depends on
      the exogenous feature columns passed to the model's constructor. None if
      the model has no exogenous feature columns.
    Raises:
      ValueError: If an exogenous feature has an unknown rank.
    """
    if not self._exogenous_feature_columns:
      # Not having any exogenous features is a special case so that models can
      # avoid superfluous updates, which may not be free of side effects due to
      # bias terms in transformations.
      return None

    flattened_features = {}
    for feature_name, feature in features.items():
      static_shape = feature.get_shape()
      if static_shape.ndims is None:
        # input_from_feature_columns does not support completely unknown
        # feature shapes, so we save on a bit of logic and provide a better
        # error message by checking that here.
        raise ValueError(
            ("Features with unknown rank are not supported. Got shape {} for "
             "feature {}.").format(static_shape, feature_name))
      dynamic_shape = array_ops.shape(feature)
      # Collapse [batch size, window size, ...] into [batch * window, ...].
      merged_leading_dimension = [dynamic_shape[0] * dynamic_shape[1]]
      flat_feature = array_ops.reshape(
          feature,
          array_ops.concat(
              [merged_leading_dimension, dynamic_shape[2:]], axis=0))
      if flat_feature.get_shape().ndims == 1:
        # Avoid shape warnings when embedding "scalar" exogenous features (those
        # with only batch and window dimensions); input_from_feature_columns
        # expects input ranks to match the embedded rank.
        flat_feature = flat_feature[:, None]
      flattened_features[feature_name] = flat_feature

    embedded_features = layers.input_from_feature_columns(
        columns_to_tensors=flattened_features,
        feature_columns=self._exogenous_feature_columns,
        trainable=True)
    # Split the merged leading dimension back into batch and window dimensions.
    restored_shape = array_ops.concat(
        [array_ops.shape(times), array_ops.shape(embedded_features)[1:]],
        axis=0)
    exogenous_regressors = array_ops.reshape(embedded_features, restored_shape)
    exogenous_regressors.set_shape(
        times.get_shape().concatenate(embedded_features.get_shape()[1:]))
    return math_ops.cast(exogenous_regressors, dtype=self.dtype)


# TODO(allenl): Add a superclass of SequentialTimeSeriesModel which fuses
# filtering/prediction/exogenous into one step, and move looping constructs to
# that class.
class SequentialTimeSeriesModel(TimeSeriesModel):
  """Base class for recurrent generative models.

  Models implementing this interface have four main functions, corresponding to
  abstract methods:
    _filtering_step: Updates state based on observations and computes a loss.
    _prediction_step: Predicts a batch of observations and new model state.
    _imputation_step: Updates model state across a gap.
    _exogenous_input_step: Updates state to account for exogenous regressors.

  Models may also specify a _window_initializer to prepare for a window of data.

  See StateSpaceModel for a concrete example of a model implementing this
  interface.

  """

  def __init__(self,
               train_output_names,
               predict_output_names,
               num_features,
               dtype=dtypes.float32,
               exogenous_feature_columns=None,
               exogenous_update_condition=None,
               static_unrolling_window_size_threshold=None):
    """Initialize a SequentialTimeSeriesModel.

    Args:
      train_output_names: A list of products/predictions returned from
          _filtering_step.
      predict_output_names: A list of products/predictions returned from
          _prediction_step.
      num_features: Number of features for the time series.
      dtype: The floating point datatype to use.
      exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
          objects. See `TimeSeriesModel`.
      exogenous_update_condition: A function taking two Tensor arguments `times`
          (shape [batch size]) and `features` (a dictionary mapping exogenous
          feature keys to Tensors with shapes [batch size, ...]) and returning a
          boolean Tensor with shape [batch size] indicating whether state should
          be updated using exogenous features for each part of the batch. Where
          it is False, no exogenous update is performed. If None (default),
          exogenous updates are always performed. Useful for avoiding "leaky"
          frequent exogenous updates when sparse updates are desired. Called
          only during graph construction.
      static_unrolling_window_size_threshold: Controls whether a `tf.while_loop`
          is used when looping over a window of data. If None, a
          `tf.while_loop` is always used. Otherwise it must be an integer, and
          the graph is replicated for each step taken whenever the window size
          is less than or equal to this value (if the window size is available
          in the static shape information of the TrainEvalFeatures.TIMES
          feature). Static unrolling generally decreases the per-step time for
          small window/batch sizes, but increases graph construction time.
    """
    super(SequentialTimeSeriesModel, self).__init__(
        num_features=num_features,
        exogenous_feature_columns=exogenous_feature_columns,
        dtype=dtype)
    # Keys of the dictionaries returned by _filtering_step/_prediction_step.
    self._train_output_names = train_output_names
    self._predict_output_names = predict_output_names
    self._exogenous_update_condition = exogenous_update_condition
    self._static_unrolling_window_size_threshold = (
        static_unrolling_window_size_threshold)

  @abc.abstractmethod
  def _filtering_step(self, current_times, current_values, state, predictions):
    """Compute a single-step loss for a batch of data.

    Implemented by subclasses.

    Args:
      current_times: A [batch size] Tensor of times for each observation.
      current_values: A Tensor of values for each observation, with batch size
          as its first dimension. per_step_batch_loss passes one window step of
          the values feature, i.e. [batch size x num features].
      state: Model state, updated to current_times.
      predictions: The outputs of _prediction_step.
    Returns:
      A tuple of (updated state, outputs):
        updated state: Model state taking current_values into account.
        outputs: A dictionary of Tensors with keys corresponding to
            self._train_output_names, plus a special "loss" key. The value
            corresponding to "loss" is minimized during training. Other outputs
            may include one-step-ahead predictions, for example a predicted
            location and scale.
    """

  @abc.abstractmethod
  def _prediction_step(self, current_times, state):
    """Compute a batch of single-step predictions.

    Implemented by subclasses.

    Args:
      current_times: A [batch size] Tensor of times for each observation.
      state: Model state, imputed to one step before current_times.
    Returns:
      A tuple of (updated state, outputs):
        updated state: Model state updated to current_times.
        outputs: A dictionary of Tensors with keys corresponding to
            self._predict_output_names.
    """

  @abc.abstractmethod
  def _imputation_step(self, current_times, state):
    """Update model state across missing values.

    Implemented by subclasses. Called at the start of every loop step to
    prepare model state for _filtering_step and _prediction_step.

    Args:
      current_times: A [batch size] Tensor; state will be imputed up to, but not
          including, these timesteps.
      state: The pre-imputation model state, Tensors with shape [batch size x
          ...].
    Returns:
      Updated/imputed model state, corresponding to `state`.
    """

  @abc.abstractmethod
  def _exogenous_input_step(
      self, current_times, current_exogenous_regressors, state):
    """Update state to account for exogenous regressors.

    Implemented by subclasses. Only called when the model was constructed with
    exogenous feature columns.

    Args:
      current_times: A [batch size] Tensor of times for the exogenous values
          being input.
      current_exogenous_regressors: A [batch size x exogenous input dimension]
          Tensor of exogenous values for each part of the batch.
      state: Model state, a possibly nested list of Tensors, each with shape
          [batch size x ...].
    Returns:
      Updated model state, structure and shapes matching the `state` argument.
    """

  # TODO(allenl): Move regularization to a separate object (optional and
  # configurable)
  def _loss_additions(self, times, values, mode):
    """Additions to per-observation normalized loss, e.g. regularization.

    The default adds nothing; subclasses may override this.

    Args:
      times: A [batch size x window size] Tensor with times for each
          observation.
      values: A [batch size x window size x num features] Tensor with values for
          each observation.
      mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL, INFER).
    Returns:
      A scalar value to add to the per-observation normalized loss.
    """
    del times, values, mode  # Unused by the default implementation
    return 0.0

  def _window_initializer(self, times, state):
    """Prepare for training or prediction on a window of data.

    Called once per window before looping over it. The default does nothing;
    subclasses may override this.

    Args:
      times: A [batch size x window size] Tensor with times for each
          observation.
      state: Model-dependent state, each with size [batch size x ...]. The
          number and type will typically be fixed by the model (for example a
          mean and variance).
    Returns:
      Nothing
    """
    del times, state  # Unused by the default implementation

  def get_batch_loss(self, features, mode, state):
    """Calls self._filtering_step. See TimeSeriesModel.get_batch_loss.

    The observed values are added to the returned predictions under the
    "observed" key.
    """
    per_observation_loss, windowed_state, outputs = self.per_step_batch_loss(
        features, mode, state)
    # per_step_batch_loss returns [batch size, window size, ...] state, whereas
    # get_batch_loss is expected to return [batch size, ...] state for the last
    # element of a window
    end_state_flat = [
        windowed_element[:, -1]
        for windowed_element in nest.flatten(windowed_state)]
    end_state = nest.pack_sequence_as(windowed_state, end_state_flat)
    outputs["observed"] = features[TrainEvalFeatures.VALUES]
    return ModelOutputs(
        loss=per_observation_loss,
        end_state=end_state,
        predictions=outputs,
        prediction_times=features[TrainEvalFeatures.TIMES])

  def _apply_exogenous_update(
      self, current_times, step_number, state, raw_features,
      embedded_exogenous_regressors):
    """Performs a conditional state update based on exogenous features.

    Args:
      current_times: A [batch size] Tensor of times for the current step.
      step_number: The position in the window of the current step, used to
          index the window dimension of the exogenous features.
      state: The model state prior to the exogenous update.
      raw_features: The features dictionary passed to the model. Entries keyed
          by PredictionFeatures.STATE_TUPLE, TrainEvalFeatures.TIMES, and
          TrainEvalFeatures.VALUES are ignored; the current step of each
          remaining entry is passed to the exogenous update condition.
      embedded_exogenous_regressors: The output of
          `_process_exogenous_features`: a [batch size x window size x
          exogenous dimension] Tensor, or None if there are no exogenous
          features.
    Returns:
      Model state with the same structure as `state`. This is `state` itself
      if there are no exogenous regressors, the output of
      `_exogenous_input_step` if no update condition was specified, and
      otherwise a per-batch-element selection between the two based on the
      update condition.
    """
    if embedded_exogenous_regressors is None:
      return state
    current_exogenous_regressors = embedded_exogenous_regressors[
        :, step_number, :]
    exogenous_updated_state = self._exogenous_input_step(
        current_times=current_times,
        current_exogenous_regressors=current_exogenous_regressors,
        state=state)
    if self._exogenous_update_condition is None:
      return exogenous_updated_state
    current_raw_exogenous_features = {
        key: value[:, step_number] for key, value in raw_features.items()
        if key not in [PredictionFeatures.STATE_TUPLE,
                       TrainEvalFeatures.TIMES,
                       TrainEvalFeatures.VALUES]}
    # The condition does not depend on which state element is being updated, so
    # it is evaluated once per step and shared between all state elements
    # rather than building duplicate ops for each one.
    should_update = self._exogenous_update_condition(
        times=current_times,
        features=current_raw_exogenous_features)
    conditionally_updated_state_flat = [
        array_ops.where(
            should_update, updated_state_element, original_state_element)
        for updated_state_element, original_state_element in zip(
            nest.flatten(exogenous_updated_state), nest.flatten(state))]
    return nest.pack_sequence_as(state, conditionally_updated_state_flat)

  def per_step_batch_loss(self, features, mode, state):
    """Computes predictions, losses, and intermediate model states.

    Each step in the window applies any exogenous update, then makes a
    prediction with `_prediction_step`, then incorporates the observed values
    with `_filtering_step`.

    Args:
      features: A dictionary with times, values, and (optionally) exogenous
          regressors. See `define_loss`.
      mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL, INFER).
      state: Model-dependent state, each with size [batch size x ...]. The
          number and type will typically be fixed by the model (for example a
          mean and variance).
    Returns:
      A tuple of (loss, filtered_states, predictions)
        loss: Average loss values across the batch and window, plus any
            additions from `_loss_additions`.
        filtered_states: For each Tensor in `state` with shape [batch size x
            ...], `filtered_states` has a Tensor with shape [batch size x window
            size x ...] with filtered state for each part of the batch and
            window.
        predictions: A dictionary with model-dependent one-step-ahead (or
            at-least-one-step-ahead with missing values) predictions, with keys
            indicating the type of prediction and values having shape [batch
            size x window size x ...]. For example state space models provide
            "mean", "covariance", and "log_likelihood".
    Raises:
      ValueError: If initialize_graph() has not been called.
    """
    self._check_graph_initialized()
    times = math_ops.cast(features[TrainEvalFeatures.TIMES], dtype=dtypes.int64)
    values = math_ops.cast(features[TrainEvalFeatures.VALUES], dtype=self.dtype)
    # Everything other than times and values is an exogenous feature.
    non_exogenous_keys = (TrainEvalFeatures.TIMES, TrainEvalFeatures.VALUES)
    exogenous_features = {}
    for feature_key, feature_value in features.items():
      if feature_key not in non_exogenous_keys:
        exogenous_features[feature_key] = feature_value
    exogenous_regressors = self._process_exogenous_features(
        times=times, features=exogenous_features)

    def _batch_loss_filtering_step(step_number, current_times, state):
      """Make a prediction and update it based on data."""
      current_values = values[:, step_number, :]
      exogenous_updated_state = self._apply_exogenous_update(
          current_times=current_times,
          step_number=step_number,
          state=state,
          raw_features=features,
          embedded_exogenous_regressors=exogenous_regressors)
      predicted_state, predictions = self._prediction_step(
          current_times=current_times, state=exogenous_updated_state)
      filtered_state, step_outputs = self._filtering_step(
          current_times=current_times,
          current_values=current_values,
          state=predicted_state,
          predictions=predictions)
      return filtered_state, step_outputs

    filtered_states, outputs = self._state_update_loop(
        times=times,
        state=state,
        state_update_fn=_batch_loss_filtering_step,
        outputs=["loss"] + self._train_output_names)
    # Since we have window-level additions to the loss, its per-step value is
    # misleading, so we avoid returning it.
    per_step_loss = outputs.pop("loss")
    per_step_loss.set_shape(times.get_shape())
    total_loss = math_ops.reduce_sum(per_step_loss)
    number_of_observations = math_ops.cast(
        math_ops.reduce_prod(array_ops.shape(times)), dtype=self.dtype)
    per_observation_loss = total_loss / number_of_observations
    per_observation_loss = per_observation_loss + self._loss_additions(
        times, values, mode)
    return per_observation_loss, filtered_states, outputs

  def predict(self, features):
    """Calls self._prediction_step in a loop. See TimeSeriesModel.predict.

    Any exogenous update is applied before each prediction step.
    """
    predict_times = ops.convert_to_tensor(
        features[PredictionFeatures.TIMES], dtypes.int64)
    start_state = features[PredictionFeatures.STATE_TUPLE]
    # Everything other than times and state is an exogenous feature.
    non_exogenous_keys = (
        PredictionFeatures.TIMES, PredictionFeatures.STATE_TUPLE)
    exogenous_features = {}
    for feature_key, feature_value in features.items():
      if feature_key not in non_exogenous_keys:
        exogenous_features[feature_key] = feature_value
    exogenous_regressors = self._process_exogenous_features(
        times=predict_times, features=exogenous_features)

    def _call_prediction_step(step_number, current_times, state):
      """Apply any exogenous update, then predict a single step."""
      exogenous_updated_state = self._apply_exogenous_update(
          current_times=current_times,
          step_number=step_number,
          state=state,
          raw_features=features,
          embedded_exogenous_regressors=exogenous_regressors)
      predicted_state, step_outputs = self._prediction_step(
          current_times=current_times, state=exogenous_updated_state)
      return predicted_state, step_outputs

    unused_states, predictions = self._state_update_loop(
        times=predict_times,
        state=start_state,
        state_update_fn=_call_prediction_step,
        outputs=self._predict_output_names)
    return predictions

  class _FakeTensorArray(object):
    """A Python-list-backed stand-in for the `write` method of a TensorArray.

    Used for easy switching between static and dynamic looping: statically
    unrolled loops accumulate into one of these rather than into a TensorArray.
    """

    def __init__(self):
      # Every value written so far, in the order it was written.
      self.values = []

    def write(self, unused_position, value):
      """Appends `value`, ignoring the position, and returns this object."""
      self.values.append(value)
      return self

  def _state_update_loop(self, times, state, state_update_fn, outputs):
    """Iterates over `times`, calling `state_update_fn` to collect outputs.

    Each step first imputes state up to the current times with
    `_imputation_step` and then calls `state_update_fn`. The loop is either a
    `tf.while_loop` or, if the static unrolling threshold passed to the
    constructor is satisfied, a Python loop replicating the graph for each step.

    Args:
      times: A [batch size x window size] Tensor of integers to iterate over.
      state: A list of model-specific state Tensors, each with shape [batch size
          x ...].
      state_update_fn: A callback taking the following arguments
            step_number: A scalar integer Tensor indicating the current position
              in the window.
            current_times: A [batch size] vector of Integers indicating times
              for each part of the batch.
            state: Current model state.
          It returns a tuple of (updated state, output_values), output_values
          being a dictionary of Tensors with keys corresponding to `outputs`.
      outputs: A list of strings indicating values which will be saved while
          iterating. Must match the keys of the dictionary returned by
          state_update_fn.
    Returns:
      A tuple of (state, output_dict)
      state: Model state after each step in the window, with the same structure
        as the `state` argument and with step values stacked along a window
        dimension inserted after the batch dimension.
      output_dict: A dictionary of outputs corresponding to those specified in
        `outputs` and computed in state_update_fn, stacked in the same way.
    """
    times = ops.convert_to_tensor(times, dtype=dtypes.int64)
    # The window size if it is known at graph construction time, otherwise None.
    window_static_shape = times.get_shape()[1].value
    if self._static_unrolling_window_size_threshold is None:
      static_unroll = False
    else:
      # The user has specified a threshold for static loop unrolling.
      if window_static_shape is None:
        # We don't have static shape information for the window size, so dynamic
        # looping is our only option.
        static_unroll = False
      elif window_static_shape <= self._static_unrolling_window_size_threshold:
        # The threshold is satisfied; unroll statically
        static_unroll = True
      else:
        # A threshold was set but not satisfied
        static_unroll = False

    self._window_initializer(times, state)

    def _run_condition(step_number, *unused):
      """Loop condition: continue while steps remain in the window."""
      del unused  # not part of while loop run condition
      # `window_size` is assigned below, before this function is first called.
      return math_ops.less(step_number, window_size)

    def _state_update_step(
        step_number, state, state_accumulators, output_accumulators,
        reuse=False):
      """Impute, then take one state_update_fn step, accumulating outputs."""
      with variable_scope.variable_scope("state_update_step", reuse=reuse):
        current_times = times[:, step_number]
        state = self._imputation_step(current_times=current_times, state=state)
        # Accumulators arrive as a list ordered like `outputs`; key them by
        # output name.
        output_accumulators_dict = {
            accumulator_key: accumulator
            for accumulator_key, accumulator
            in zip(outputs, output_accumulators)}
        step_state, output_values = state_update_fn(
            step_number=step_number,
            current_times=current_times,
            state=state)
        # NOTE(review): this check and the one below are plain asserts, so they
        # are skipped when Python runs with -O.
        assert set(output_values.keys()) == set(outputs)
        new_output_accumulators = []
        # Write in the order of `outputs` so the accumulator list keeps a
        # stable order between steps.
        for output_key in outputs:
          accumulator = output_accumulators_dict[output_key]
          output_value = output_values[output_key]
          new_output_accumulators.append(
              accumulator.write(step_number, output_value))
        flat_step_state = nest.flatten(step_state)
        assert len(state_accumulators) == len(flat_step_state)
        new_state_accumulators = []
        new_state_flat = []
        for step_state_value, state_accumulator, original_state in zip(
            flat_step_state, state_accumulators, nest.flatten(state)):
          # Make sure the static shape information is complete so while_loop
          # does not complain about shape information changing.
          step_state_value.set_shape(original_state.get_shape())
          new_state_flat.append(step_state_value)
          new_state_accumulators.append(state_accumulator.write(
              step_number, step_state_value))
        step_state = nest.pack_sequence_as(state, new_state_flat)
        # Same structure as the loop variables (minus `reuse`).
        return (step_number + 1, step_state,
                new_state_accumulators, new_output_accumulators)

    # The dynamic window size, used for the loop condition and to size
    # TensorArrays.
    window_size = array_ops.shape(times)[1]

    def _window_size_tensor_array(dtype):
      """Creates an accumulator matching the chosen looping strategy."""
      if static_unroll:
        return self._FakeTensorArray()
      else:
        return tensor_array_ops.TensorArray(
            dtype=dtype, size=window_size, dynamic_size=False)

    # Loop variables: step number, model state, one accumulator per flattened
    # state element, and one accumulator per output. Output accumulators are
    # all created with self.dtype.
    initial_loop_arguments = [
        array_ops.zeros([], dtypes.int32),
        state,
        [_window_size_tensor_array(element.dtype)
         for element in nest.flatten(state)],
        [_window_size_tensor_array(self.dtype) for _ in outputs]]
    if static_unroll:
      arguments = initial_loop_arguments
      for step_number in range(times.get_shape()[1].value):
        # The step number is passed as a constant; the remaining loop variables
        # come from the previous step.
        arguments = _state_update_step(
            array_ops.constant(step_number, dtypes.int32), *arguments[1:],
            reuse=(step_number > 0))  # Variable sharing between steps
    else:
      arguments = control_flow_ops.while_loop(
          cond=_run_condition,
          body=_state_update_step,
          loop_vars=initial_loop_arguments)
    (_, _, state_loop_result, outputs_loop_result) = arguments

    def _stack_and_transpose(tensor_array):
      """Stack and re-order the dimensions of a TensorArray."""
      if static_unroll:
        # Values were collected in a Python list; stack them directly along
        # the window dimension.
        return array_ops.stack(tensor_array.values, axis=1)
      else:
        # TensorArrays from while_loop stack with window size as the first
        # dimension, so this function swaps it and the batch dimension to
        # maintain the [batch x window size x ...] convention used elsewhere.
        stacked = tensor_array.stack()
        return array_ops.transpose(
            stacked,
            perm=array_ops.concat([[1, 0], math_ops.range(
                2, array_ops.rank(stacked))], 0))

    outputs_dict = {output_key: _stack_and_transpose(output)
                    for output_key, output
                    in zip(outputs, outputs_loop_result)}
    # Restore the (possibly nested) structure of the input state.
    full_state = nest.pack_sequence_as(
        state,
        [_stack_and_transpose(state_element)
         for state_element in state_loop_result])
    return full_state, outputs_dict