author    A. Unique TensorFlower <gardener@tensorflow.org>    2018-09-21 23:45:44 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>     2018-09-21 23:49:53 -0700
commit    6f0bdfd788ebaaa55c2b4022c70c8bad2cc5dd2c (patch)
tree      c2aa2b412ac1d4b5c6b43cddde5a6e06385aaddf /tensorflow/contrib/timeseries
parent    94b5ff16ce1530e09bc30c51709b0596ff61103f (diff)
Wrap ARModel and LSTMPredictionModel into an LSTMAutoRegressor estimator
PiperOrigin-RevId: 214091820
Diffstat (limited to 'tensorflow/contrib/timeseries')
-rw-r--r-- tensorflow/contrib/timeseries/python/timeseries/ar_model.py        |  65
-rw-r--r-- tensorflow/contrib/timeseries/python/timeseries/estimators.py      | 157
-rw-r--r-- tensorflow/contrib/timeseries/python/timeseries/estimators_test.py |  35
3 files changed, 242 insertions, 15 deletions
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
index 1d27fffc62..9bbe87e301 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
@@ -191,6 +191,43 @@ class ARModel(model.TimeSeriesModel):
Note that this class can also be used to regress against time only by setting
the input_window_size to zero.
+
+ Each periodicity in the `periodicities` arg is divided by the
+ `num_time_buckets` into time buckets that are represented as features added
+ to the model.
+
+ A good heuristic for picking an appropriate periodicity for a given data set
+ would be the length of cycles in the data. For example, energy usage in a
+ home is typically cyclic each day. If the time feature in a home energy
+ usage dataset is in the unit of hours, then 24 would be an appropriate
+ periodicity. Similarly, a good heuristic for `num_time_buckets` is how often
+ the data is expected to change within the cycle. For the aforementioned home
+ energy usage dataset and periodicity of 24, then 48 would be a reasonable
+ value if usage is expected to change every half hour.
+
+ Each feature's value for a given example with time t is the difference
+ between t and the start of the time bucket it falls under. If it doesn't fall
+ under a feature's associated time bucket, then that feature's value is zero.
+
+ For example: if `periodicities` = (9, 12) and `num_time_buckets` = 3, then 6
+ features would be added to the model, 3 for periodicity 9 and 3 for
+ periodicity 12.
+
+ For an example data point where t = 17:
+ - It's in the 3rd time bucket for periodicity 9 (the 2nd period is 9-18
+ and its 3rd time bucket is 15-18).
+ - It's in the 2nd time bucket for periodicity 12 (the 2nd period is 12-24
+ and its 2nd time bucket is 16-20).
+
+ Therefore the 6 added features for this row with t = 17 would be:
+
+ # Feature name (periodicity#_timebucket#), feature value
+ P9_T1, 0 # not in first time bucket
+ P9_T2, 0 # not in second time bucket
+ P9_T3, 2 # 17 - 15 since 15 is the start of the 3rd time bucket
+ P12_T1, 0 # not in first time bucket
+ P12_T2, 1 # 17 - 16 since 16 is the start of the 2nd time bucket
+ P12_T3, 0 # not in third time bucket
"""
SQUARED_LOSS = "squared_loss"
NORMAL_LIKELIHOOD_LOSS = "normal_likelihood_loss"
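To make the bucketing described in the docstring above concrete, here is a minimal standalone sketch; the helper name `bucket_features` is hypothetical (not part of the library), but it reproduces the t = 17 feature table:

```python
def bucket_features(t, periodicities, num_time_buckets):
  """Sketch of ARModel's per-periodicity time-bucket features for one time t."""
  features = {}
  for p in periodicities:
    bucket_size = p / float(num_time_buckets)
    phase = t % p  # position of t within its cycle of length p
    active = int(phase // bucket_size)  # 0-based index of the bucket holding t
    for b in range(num_time_buckets):
      # Offset from the bucket's start if t falls in this bucket, else zero.
      features["P%d_T%d" % (p, b + 1)] = (
          phase - b * bucket_size if b == active else 0)
  return features

# Reproduces the table above for t = 17:
# {'P9_T1': 0, 'P9_T2': 0, 'P9_T3': 2.0, 'P12_T1': 0, 'P12_T2': 1.0, 'P12_T3': 0}
print(bucket_features(17, (9, 12), 3))
```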
@@ -208,7 +245,9 @@ class ARModel(model.TimeSeriesModel):
Args:
periodicities: periodicities of the input data, in the same units as the
- time feature. Note this can be a single value or a list of values for
+ time feature (for example 24 if feeding hourly data with a daily
+ periodicity, or 60 * 24 if feeding minute-level data with daily
+ periodicity). Note this can be a single value or a list of values for
multiple periodicities.
input_window_size: Number of past time steps of data to look at when doing
the regression.
@@ -218,21 +257,18 @@ class ARModel(model.TimeSeriesModel):
prediction_model_factory: A callable taking arguments `num_features`,
`input_window_size`, and `output_window_size` and returning a
`tf.keras.Model`. The `Model`'s `call()` takes two arguments: an input
- window and an output window, and returns a dictionary of
- predictions. See `FlatPredictionModel` for an example. Example usage:
+ window and an output window, and returns a dictionary of predictions.
+ See `FlatPredictionModel` for an example. Example usage:
- ```python
- model = ar_model.ARModel(
- periodicities=2, num_features=3,
- prediction_model_factory=functools.partial(
- FlatPredictionModel,
- hidden_layer_sizes=[10, 10]))
- ```
+ ```python
+ model = ar_model.ARModel(
+     periodicities=2, num_features=3,
+     prediction_model_factory=functools.partial(
+         FlatPredictionModel,
+         hidden_layer_sizes=[10, 10]))
+ ```
The default model computes predictions as a linear function of flattened
input and output windows.
num_time_buckets: Number of buckets into which to divide (time %
- periodicity) for generating time based features.
+ periodicity). This value multiplied by the number of periodicities is
+ the number of time features added to the model.
loss: Loss function to use for training. Currently supported values are
SQUARED_LOSS and NORMAL_LIKELIHOOD_LOSS. Note that for
NORMAL_LIKELIHOOD_LOSS, we train the covariance term as well. For
@@ -240,10 +276,9 @@ class ARModel(model.TimeSeriesModel):
observations and predictions, while the training loss is computed on
normalized data (if input statistics are available).
exogenous_feature_columns: A list of `tf.feature_column`s (for example
- `tf.feature_column.embedding_column`) corresponding to exogenous
- features which provide extra information to the model but are not part
- of the series to be predicted. Passed to
- `tf.feature_column.input_layer`.
+ `tf.feature_column.embedding_column`) corresponding to
+ features which provide extra information to the model but are not part
+ of the series to be predicted.
"""
self._model_factory = prediction_model_factory
self.input_window_size = input_window_size
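The `prediction_model_factory` contract documented above (a callable taking `num_features`, `input_window_size`, and `output_window_size` and returning a `tf.keras.Model` whose `call()` maps an input window and an output window to a dictionary of predictions) can be sketched with a toy model. `TinyPredictionModel` is illustrative only, and the `mean`/`covariance` output keys are an assumption mirroring `FlatPredictionModel`:

```python
import tensorflow as tf

class TinyPredictionModel(tf.keras.Model):
  """Toy model satisfying the `prediction_model_factory` contract (sketch)."""

  def __init__(self, num_features, input_window_size, output_window_size):
    super(TinyPredictionModel, self).__init__()
    self._num_features = num_features
    self._output_window_size = output_window_size
    # A single linear layer regresses the whole output window at once.
    self._dense = tf.keras.layers.Dense(output_window_size * num_features)

  def call(self, input_window_features, output_window_features):
    # input_window_features: [batch, input_window_size, features], flattened
    # and mapped to [batch, output_window_size, num_features] predictions.
    flat = tf.keras.layers.Flatten()(input_window_features)
    mean = tf.reshape(self._dense(flat),
                      [-1, self._output_window_size, self._num_features])
    # Output keys assumed to mirror FlatPredictionModel.
    return {"mean": mean, "covariance": tf.ones_like(mean)}

# Would be passed as prediction_model_factory=TinyPredictionModel, since the
# factory is called with exactly these three keyword arguments.
```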
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index 0ddc4b4144..af68aa03cf 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -30,6 +30,7 @@ from tensorflow.contrib.timeseries.python.timeseries.state_space_models import s
from tensorflow.contrib.timeseries.python.timeseries.state_space_models.filtering_postprocessor import StateInterpolatingAnomalyDetector
from tensorflow.python.estimator import estimator_lib
+from tensorflow.python.estimator.canned import optimizers
from tensorflow.python.estimator.export import export_lib
from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import dtypes
@@ -386,6 +387,162 @@ class ARRegressor(TimeSeriesRegressor):
config=config)
+# TODO(b/113684821): Add detailed documentation on what the input_fn should do.
+# Add an example of making and returning a Dataset object. Determine if
+# endogenous features can be passed in as FeatureColumns. Move ARModel's loss
+# functions into a more general location.
+class LSTMAutoRegressor(TimeSeriesRegressor):
+ """An Estimator for an LSTM autoregressive model.
+
+ LSTMAutoRegressor is a window-based model: it reads fixed windows of length
+ `input_window_size` and predicts fixed windows of length
+ `output_window_size`. These two parameters must add up to the window_size
+ of data returned by the `input_fn`.
+
+ Each periodicity in the `periodicities` arg is divided by the `num_timesteps`
+ into timesteps that are represented as time features added to the model.
+
+ A good heuristic for picking an appropriate periodicity for a given data set
+ would be the length of cycles in the data. For example, energy usage in a
+ home is typically cyclic each day. If the time feature in a home energy
+ usage dataset is in the unit of hours, then 24 would be an appropriate
+ periodicity. Similarly, a good heuristic for `num_timesteps` is how often the
+ data is expected to change within the cycle. For the aforementioned home
+ energy usage dataset and periodicity of 24, then 48 would be a reasonable
+ value if usage is expected to change every half hour.
+
+ Each feature's value for a given example with time t is the difference
+ between t and the start of the timestep it falls under. If it doesn't fall
+ under a feature's associated timestep, then that feature's value is zero.
+
+ For example: if `periodicities` = (9, 12) and `num_timesteps` = 3, then 6
+ features would be added to the model, 3 for periodicity 9 and 3 for
+ periodicity 12.
+
+ For an example data point where t = 17:
+ - It's in the 3rd timestep for periodicity 9 (the 2nd period is 9-18
+ and its 3rd timestep is 15-18).
+ - It's in the 2nd timestep for periodicity 12 (the 2nd period is 12-24
+ and its 2nd timestep is 16-20).
+
+ Therefore the 6 added features for this row with t = 17 would be:
+
+ # Feature name (periodicity#_timestep#), feature value
+ P9_T1, 0 # not in first timestep
+ P9_T2, 0 # not in second timestep
+ P9_T3, 2 # 17 - 15 since 15 is the start of the 3rd timestep
+ P12_T1, 0 # not in first timestep
+ P12_T2, 1 # 17 - 16 since 16 is the start of the 2nd timestep
+ P12_T3, 0 # not in third timestep
+
+ Example Code:
+
+ ```python
+ extra_feature_columns = (
+ feature_column.numeric_column("exogenous_variable"),
+ )
+
+ estimator = LSTMAutoRegressor(
+ periodicities=10,
+ input_window_size=10,
+ output_window_size=5,
+ model_dir="/path/to/model/dir",
+ num_features=1,
+ extra_feature_columns=extra_feature_columns,
+ num_timesteps=50,
+ num_units=10,
+ optimizer=tf.train.ProximalAdagradOptimizer(...))
+
+ # Input builders
+ def input_fn_train():
+ return {
+ "times": tf.range(15)[None, :],
+ "values": tf.random_normal(shape=[1, 15, 1])
+ }
+ estimator.train(input_fn=input_fn_train, steps=100)
+
+ def input_fn_eval():
+ pass
+ metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
+
+ def input_fn_predict():
+ pass
+ predictions = estimator.predict(input_fn=input_fn_predict)
+ ```
+ """
+
+ def __init__(self,
+ periodicities,
+ input_window_size,
+ output_window_size,
+ model_dir=None,
+ num_features=1,
+ extra_feature_columns=None,
+ num_timesteps=10,
+ loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS,
+ num_units=128,
+ optimizer="Adam",
+ config=None):
+ """Initialize the Estimator.
+
+ Args:
+ periodicities: periodicities of the input data, in the same units as the
+ time feature (for example 24 if feeding hourly data with a daily
+ periodicity, or 60 * 24 if feeding minute-level data with daily
+ periodicity). Note this can be a single value or a list of values for
+ multiple periodicities.
+ input_window_size: Number of past time steps of data to look at when doing
+ the regression.
+ output_window_size: Number of future time steps to predict. Note that
+ setting this value to > 1 empirically seems to give a better fit.
+ model_dir: Directory to save model parameters, graph, etc. This can
+ also be used to load checkpoints from the directory into an estimator
+ to continue training a previously saved model.
+ num_features: The dimensionality of the time series (default value is
+ one for univariate, more than one for multivariate).
+ extra_feature_columns: A list of `tf.feature_column`s (for example
+ `tf.feature_column.embedding_column`) corresponding to features which
+ provide extra information to the model but are not part of the series to
+ be predicted.
+ num_timesteps: Number of buckets into which to divide (time %
+ periodicity). This value multiplied by the number of periodicities is
+ the number of time features added to the model.
+ loss: Loss function to use for training. Currently supported values are
+ SQUARED_LOSS and NORMAL_LIKELIHOOD_LOSS. Note that for
+ NORMAL_LIKELIHOOD_LOSS, we train the covariance term as well. For
+ SQUARED_LOSS, the evaluation loss is reported based on un-scaled
+ observations and predictions, while the training loss is computed on
+ normalized data.
+ num_units: The size of the hidden state in the encoder and decoder LSTM
+ cells.
+ optimizer: string, `tf.train.Optimizer` object, or callable that defines
+ the optimizer algorithm to use for training. Defaults to the Adam
+ optimizer with a learning rate of 0.01.
+ config: Optional `estimator.RunConfig` object to configure the runtime
+ settings.
+ """
+ optimizer = optimizers.get_optimizer_instance(
+ optimizer, learning_rate=0.01)
+ model = ar_model.ARModel(
+ periodicities=periodicities,
+ input_window_size=input_window_size,
+ output_window_size=output_window_size,
+ num_features=num_features,
+ exogenous_feature_columns=extra_feature_columns,
+ num_time_buckets=num_timesteps,
+ loss=loss,
+ prediction_model_factory=functools.partial(
+ ar_model.LSTMPredictionModel, num_units=num_units))
+ state_manager = state_management.FilteringOnlyStateManager()
+ super(LSTMAutoRegressor, self).__init__(
+ model=model,
+ state_manager=state_manager,
+ optimizer=optimizer,
+ model_dir=model_dir,
+ config=config,
+ head_type=ts_head_lib.OneShotPredictionHead)
+
+
class StateSpaceRegressor(TimeSeriesRegressor):
"""An Estimator for general state space models."""
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index 83260fc59a..6ec7184c68 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -226,5 +226,40 @@ class TimeSeriesRegressorTest(test.TestCase):
input_pipeline.NumpyReader(numpy_data)),
steps=1)
+ def test_ar_lstm_regressor(self):
+ dtype = dtypes.float32
+ model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+ exogenous_feature_columns = (
+ feature_column.numeric_column("exogenous"),
+ )
+ estimator = estimators.LSTMAutoRegressor(
+ periodicities=10,
+ input_window_size=10,
+ output_window_size=6,
+ model_dir=model_dir,
+ num_features=1,
+ extra_feature_columns=exogenous_feature_columns,
+ num_units=10,
+ config=_SeedRunConfig())
+ times = numpy.arange(20, dtype=numpy.int64)
+ values = numpy.arange(20, dtype=dtype.as_numpy_dtype)
+ exogenous = numpy.arange(20, dtype=dtype.as_numpy_dtype)
+ features = {
+ feature_keys.TrainEvalFeatures.TIMES: times,
+ feature_keys.TrainEvalFeatures.VALUES: values,
+ "exogenous": exogenous
+ }
+ train_input_fn = input_pipeline.RandomWindowInputFn(
+ input_pipeline.NumpyReader(features), shuffle_seed=2, num_threads=1,
+ batch_size=16, window_size=16)
+ eval_input_fn = input_pipeline.RandomWindowInputFn(
+ input_pipeline.NumpyReader(features), shuffle_seed=3, num_threads=1,
+ batch_size=16, window_size=16)
+ estimator.train(input_fn=train_input_fn, steps=1)
+ evaluation = estimator.evaluate(
+ input_fn=eval_input_fn, steps=1)
+ self.assertAllEqual(evaluation["loss"], evaluation["average_loss"])
+ self.assertAllEqual([], evaluation["loss"].shape)
+
if __name__ == "__main__":
test.main()
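Since `LSTMAutoRegressor` is constructed with `OneShotPredictionHead`, a natural follow-on to the evaluation in `test_ar_lstm_regressor` is exporting a SavedModel for one-shot inference. A hedged sketch, assuming the `build_raw_serving_input_receiver_fn` helper that `TimeSeriesRegressor` defines elsewhere in estimators.py; treat the exact export call as an assumption, not part of this commit:

```python
import tempfile

# Assumes `estimator` was trained as in test_ar_lstm_regressor above.
export_dir = tempfile.mkdtemp()
input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
export_location = estimator.export_savedmodel(export_dir, input_receiver_fn)
# The exported signature consumes a raw history window of times/values (plus
# any exogenous features) and emits output_window_size predicted steps.
```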