author    Allen Lavoie <allenl@google.com>  2018-04-23 16:46:41 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-04-23 16:48:59 -0700
commit    84c73c2b4d0318bfd78a53ab6051169795604650 (patch)
tree      02dcef8a6e945707e276a6d0c9cc52c484afc9fb /tensorflow/contrib/timeseries
parent    a72155d58726d4dbb92d5d6b0f3290976bbdaa1c (diff)
TFTS: Support exogenous features in ARRegressor
They get flattened with the endogenous features as input to the model. Unlike
endogenous features, they're specified for the whole window when making
predictions.

Adds an ARRegressor example which uses exogenous features.

PiperOrigin-RevId: 194006630
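For orientation, a minimal sketch of the resulting API, written against the
TF 1.x contrib interfaces this change touches (the feature name "exogenous"
and the toy data are illustrative, not part of the commit): the exogenous
feature is declared as a feature column, included alongside TIMES and VALUES
during training, and, unlike the series itself, supplied for the entire
window when predicting.

    import numpy as np
    import tensorflow as tf

    exogenous_column = tf.feature_column.numeric_column("exogenous")
    estimator = tf.contrib.timeseries.ARRegressor(
        periodicities=10, input_window_size=10, output_window_size=6,
        num_features=1, exogenous_feature_columns=[exogenous_column])

    features = {
        tf.contrib.timeseries.TrainEvalFeatures.TIMES:
            np.arange(20, dtype=np.int64),
        tf.contrib.timeseries.TrainEvalFeatures.VALUES:
            np.arange(20, dtype=np.float32),
        "exogenous": np.arange(20, dtype=np.float32),
    }
    reader = tf.contrib.timeseries.NumpyReader(features)
    # ARRegressor trains on windows of exactly
    # input_window_size + output_window_size (= 16) steps.
    estimator.train(
        input_fn=tf.contrib.timeseries.RandomWindowInputFn(
            reader, batch_size=4, window_size=16),
        steps=100)
    evaluation = estimator.evaluate(
        input_fn=tf.contrib.timeseries.WholeDatasetInputFn(reader), steps=1)
    # Unlike VALUES, exogenous values must be provided for the whole
    # window being predicted: shape [batch, steps, 1] here.
    (predictions,) = tuple(estimator.predict(
        input_fn=tf.contrib.timeseries.predict_continuation_input_fn(
            evaluation, steps=10,
            exogenous_features={
                "exogenous": np.arange(10, dtype=np.float32)[None, :, None]})))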
Diffstat (limited to 'tensorflow/contrib/timeseries')
-rw-r--r--  tensorflow/contrib/timeseries/examples/known_anomaly.py             75
-rw-r--r--  tensorflow/contrib/timeseries/examples/known_anomaly_test.py        18
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/ar_model.py        173
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py     8
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/estimators.py       11
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/estimators_test.py  48
6 files changed, 255 insertions(+), 78 deletions(-)
diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly.py b/tensorflow/contrib/timeseries/examples/known_anomaly.py
index e77628ddd3..71621abc71 100644
--- a/tensorflow/contrib/timeseries/examples/known_anomaly.py
+++ b/tensorflow/contrib/timeseries/examples/known_anomaly.py
@@ -41,17 +41,8 @@ _MODULE_PATH = path.dirname(__file__)
_DATA_FILE = path.join(_MODULE_PATH, "data/changepoints.csv")
-def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300):
- """Training, evaluating, and predicting on a series with changepoints."""
-
- # Indicate the format of our exogenous feature, in this case a string
- # representing a boolean value.
- string_feature = tf.feature_column.categorical_column_with_vocabulary_list(
- key="is_changepoint", vocabulary_list=["no", "yes"])
- # Specify the way this feature is presented to the model, here using a one-hot
- # encoding.
- one_hot_feature = tf.feature_column.indicator_column(
- categorical_column=string_feature)
+def state_space_estimator(exogenous_feature_columns):
+ """Constructs a StructuralEnsembleRegressor and its batch/window sizes."""
def _exogenous_update_condition(times, features):
del times # unused
@@ -62,14 +53,48 @@ def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300):
# no changepoint.
return tf.equal(tf.squeeze(features["is_changepoint"], axis=-1), "yes")
- estimator = tf.contrib.timeseries.StructuralEnsembleRegressor(
- periodicities=12,
- # Extract a smooth period by constraining the number of latent values
- # being cycled between.
- cycle_num_latent_values=3,
- num_features=1,
- exogenous_feature_columns=[one_hot_feature],
- exogenous_update_condition=_exogenous_update_condition)
+ return (
+ tf.contrib.timeseries.StructuralEnsembleRegressor(
+ periodicities=12,
+ # Extract a smooth period by constraining the number of latent values
+ # being cycled between.
+ cycle_num_latent_values=3,
+ num_features=1,
+ exogenous_feature_columns=exogenous_feature_columns,
+ exogenous_update_condition=_exogenous_update_condition),
+ # Use truncated backpropagation with a window size of 64, batching
+ # together 4 of these windows (random offsets) per training step. Training
+ # with exogenous features often requires somewhat larger windows.
+ 4, 64)
+
+
+def autoregressive_estimator(exogenous_feature_columns):
+ """Constructs an ARRegressor and its batch/window sizes."""
+ input_window_size = 8
+ output_window_size = 2
+ return (
+ tf.contrib.timeseries.ARRegressor(
+ periodicities=12,
+ num_features=1,
+ input_window_size=input_window_size,
+ output_window_size=output_window_size,
+ exogenous_feature_columns=exogenous_feature_columns),
+ 64, input_window_size + output_window_size)
+
+
+def train_and_evaluate_exogenous(
+ estimator_fn, csv_file_name=_DATA_FILE, train_steps=300):
+ """Training, evaluating, and predicting on a series with changepoints."""
+ # Indicate the format of our exogenous feature, in this case a string
+ # representing a boolean value.
+ string_feature = tf.feature_column.categorical_column_with_vocabulary_list(
+ key="is_changepoint", vocabulary_list=["no", "yes"])
+ # Specify the way this feature is presented to the model, here using a one-hot
+ # encoding.
+ one_hot_feature = tf.feature_column.indicator_column(
+ categorical_column=string_feature)
+
+ estimator, batch_size, window_size = estimator_fn(
+ exogenous_feature_columns=[one_hot_feature])
reader = tf.contrib.timeseries.CSVReader(
csv_file_name,
# Indicate the format of our CSV file. First we have two standard columns,
@@ -85,10 +110,7 @@ def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300):
# This CSV has a header line; here we just ignore it.
skip_header_lines=1)
train_input_fn = tf.contrib.timeseries.RandomWindowInputFn(
- # Use truncated backpropagation with a window size of 64, batching
- # together 4 of these windows (random offsets) per training step. Training
- # with exogenous features often requires somewhat larger windows.
- reader, batch_size=4, window_size=64)
+ reader, batch_size=batch_size, window_size=window_size)
estimator.train(input_fn=train_input_fn, steps=train_steps)
evaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader)
evaluation = estimator.evaluate(input_fn=evaluation_input_fn, steps=1)
@@ -145,7 +167,12 @@ def main(unused_argv):
if not HAS_MATPLOTLIB:
raise ImportError(
"Please install matplotlib to generate a plot from this example.")
- make_plot("Ignoring a known anomaly", *train_and_evaluate_exogenous())
+ make_plot("Ignoring a known anomaly (state space)",
+ *train_and_evaluate_exogenous(
+ estimator_fn=state_space_estimator))
+ make_plot("Ignoring a known anomaly (autoregressive)",
+ *train_and_evaluate_exogenous(
+ estimator_fn=autoregressive_estimator, train_steps=3000))
pyplot.show()
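As background for the feature-column plumbing in this example, a small
self-contained sketch (toy inputs, not from the commit) of what the
"is_changepoint" one-hot encoding produces when run through
tf.feature_column.input_layer:

    import tensorflow as tf

    string_feature = tf.feature_column.categorical_column_with_vocabulary_list(
        key="is_changepoint", vocabulary_list=["no", "yes"])
    one_hot_feature = tf.feature_column.indicator_column(
        categorical_column=string_feature)
    # Two example timesteps: one without and one with a changepoint.
    dense = tf.feature_column.input_layer(
        features={"is_changepoint": tf.constant([["no"], ["yes"]])},
        feature_columns=[one_hot_feature])
    with tf.Session() as session:
        session.run(tf.tables_initializer())
        print(session.run(dense))  # [[1., 0.], [0., 1.]]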
diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly_test.py b/tensorflow/contrib/timeseries/examples/known_anomaly_test.py
index c3e307cad8..8c64f2e186 100644
--- a/tensorflow/contrib/timeseries/examples/known_anomaly_test.py
+++ b/tensorflow/contrib/timeseries/examples/known_anomaly_test.py
@@ -23,12 +23,24 @@ from tensorflow.contrib.timeseries.examples import known_anomaly
from tensorflow.python.platform import test
-class KnownAnaomalyExampleTest(test.TestCase):
+class KnownAnomalyExampleTest(test.TestCase):
- def test_shapes_and_variance_structural(self):
+ def test_shapes_and_variance_structural_ar(self):
(times, observed, all_times, mean, upper_limit, lower_limit,
anomaly_locations) = known_anomaly.train_and_evaluate_exogenous(
- train_steps=50)
+ train_steps=1, estimator_fn=known_anomaly.autoregressive_estimator)
+ self.assertAllEqual(
+ anomaly_locations,
+ [25, 50, 75, 100, 125, 150, 175, 249])
+ self.assertAllEqual(all_times.shape, mean.shape)
+ self.assertAllEqual(all_times.shape, upper_limit.shape)
+ self.assertAllEqual(all_times.shape, lower_limit.shape)
+ self.assertAllEqual(times.shape, observed.shape)
+
+ def test_shapes_and_variance_structural_ssm(self):
+ (times, observed, all_times, mean, upper_limit, lower_limit,
+ anomaly_locations) = known_anomaly.train_and_evaluate_exogenous(
+ train_steps=50, estimator_fn=known_anomaly.state_space_estimator)
self.assertAllEqual(
anomaly_locations,
[25, 50, 75, 100, 125, 150, 175, 249])
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
index 4f6527a546..558d9480b4 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
@@ -60,7 +60,8 @@ class ARModel(model.TimeSeriesModel):
num_features,
num_time_buckets=10,
loss=NORMAL_LIKELIHOOD_LOSS,
- hidden_layer_sizes=None):
+ hidden_layer_sizes=None,
+ exogenous_feature_columns=None):
"""Constructs an auto-regressive model.
Args:
@@ -81,6 +82,11 @@ class ARModel(model.TimeSeriesModel):
observations and predictions, while the training loss is computed on
normalized data (if input statistics are available).
hidden_layer_sizes: list of sizes of hidden layers.
+ exogenous_feature_columns: A list of `tf.feature_column`s (for example
+ `tf.feature_column.embedding_column`) corresponding to exogenous
+ features which provide extra information to the model but are not part
+ of the series to be predicted. Passed to
+ `tf.feature_column.input_layer`.
"""
self.input_window_size = input_window_size
self.output_window_size = output_window_size
@@ -90,7 +96,12 @@ class ARModel(model.TimeSeriesModel):
self.window_size = self.input_window_size + self.output_window_size
self.loss = loss
super(ARModel, self).__init__(
- num_features=num_features)
+ num_features=num_features,
+ exogenous_feature_columns=exogenous_feature_columns)
+ if exogenous_feature_columns is not None:
+ self.exogenous_size = self._get_exogenous_embedding_shape()[-1]
+ else:
+ self.exogenous_size = 0
assert num_time_buckets > 0
self._buckets = int(num_time_buckets)
if periodicities is None or not periodicities:
@@ -110,7 +121,10 @@ class ARModel(model.TimeSeriesModel):
# that the serving input_receiver_fn gets placeholder shapes correct.
return (array_ops.zeros([self.input_window_size], dtype=dtypes.int64),
array_ops.zeros(
- [self.input_window_size, self.num_features], dtype=self.dtype))
+ [self.input_window_size, self.num_features], dtype=self.dtype),
+ array_ops.zeros(
+ [self.input_window_size, self.exogenous_size],
+ dtype=self.dtype))
# TODO(allenl,agarwal): Support sampling for AR.
def random_model_parameters(self, seed=None):
@@ -163,7 +177,7 @@ class ARModel(model.TimeSeriesModel):
activations.append((activation, activation_size))
return activations
- def prediction_ops(self, times, values):
+ def prediction_ops(self, times, values, exogenous_regressors):
"""Compute model predictions given input data.
Args:
@@ -173,6 +187,8 @@ class ARModel(model.TimeSeriesModel):
prediction times.
values: A [batch size, self.input_window_size, self.num_features] Tensor
with input features.
+ exogenous_regressors: A [batch size, self.window_size,
+ self.exogenous_size] Tensor with exogenous features.
Returns:
Tuple (predicted_mean, predicted_covariance), where each element is a
Tensor with shape [batch size, self.output_window_size,
@@ -183,25 +199,33 @@ class ARModel(model.TimeSeriesModel):
if self.input_window_size:
values.get_shape().assert_is_compatible_with(
[None, self.input_window_size, self.num_features])
+ if exogenous_regressors is not None:
+ exogenous_regressors.get_shape().assert_is_compatible_with(
+ [None, self.window_size, self.exogenous_size])
# Create input features.
+ activation_components = []
if self._periods:
_, time_features = self._compute_time_features(times)
activation_size = self.window_size * self._buckets * len(self._periods)
- activation = array_ops.reshape(time_features, [-1, activation_size])
+ activation_components.append(
+ array_ops.reshape(time_features, [-1, activation_size]))
else:
activation_size = 0
- activation = None
-
if self.input_window_size:
inp = array_ops.slice(values, [0, 0, 0], [-1, self.input_window_size, -1])
inp_size = self.input_window_size * self.num_features
inp = array_ops.reshape(inp, [-1, inp_size])
- if activation is not None:
- activation = array_ops.concat([inp, activation], 1)
- else:
- activation = inp
+ activation_components.append(inp)
activation_size += inp_size
+ if self.exogenous_size:
+ exogenous_size = self.window_size * self.exogenous_size
+ activation_size += exogenous_size
+ exogenous_flattened = array_ops.reshape(
+ exogenous_regressors, [-1, exogenous_size])
+ activation_components.append(exogenous_flattened)
assert activation_size
+ assert activation_components
+ activation = array_ops.concat(activation_components, axis=1)
activations.append((activation, activation_size))
# Create hidden layers.
activations += self._create_hidden_stack(activation, activation_size)
@@ -228,6 +252,19 @@ class ARModel(model.TimeSeriesModel):
math_ops.reduce_prod(array_ops.shape(targets)), loss_op.dtype)
return loss_op
+ def _process_exogenous_features(self, times, features):
+ embedded = super(ARModel, self)._process_exogenous_features(
+ times=times, features=features)
+ if embedded is None:
+ assert self.exogenous_size == 0
+ # No embeddings. Return a zero-size [batch, times, 0] array so we don't
+ # have to special case it downstream.
+ return array_ops.zeros(
+ array_ops.concat([array_ops.shape(times), constant_op.constant([0])],
+ axis=0))
+ else:
+ return embedded
+
# TODO(allenl, agarwal): Consider better ways of warm-starting predictions.
def predict(self, features):
"""Computes predictions multiple steps into the future.
@@ -243,6 +280,7 @@ class ARModel(model.TimeSeriesModel):
segment of the time series before `TIMES`. This data is used
to start off the autoregressive computation. This should have data for
at least self.input_window_size timesteps.
+ Any exogenous features should also be included, with shapes prefixed by the shape of `TIMES`.
Returns:
A dictionary with keys, "mean", "covariance". The
values are Tensors of shape [batch_size, predict window size,
@@ -250,25 +288,39 @@ class ARModel(model.TimeSeriesModel):
"""
predict_times = math_ops.cast(
ops.convert_to_tensor(features[PredictionFeatures.TIMES]), dtypes.int32)
+ exogenous_regressors = self._process_exogenous_features(
+ times=predict_times,
+ features={key: value for key, value in features.items()
+ if key not in [TrainEvalFeatures.TIMES,
+ TrainEvalFeatures.VALUES,
+ PredictionFeatures.STATE_TUPLE]})
+ with ops.control_dependencies(
+ [check_ops.assert_equal(array_ops.shape(predict_times)[1],
+ array_ops.shape(exogenous_regressors)[1])]):
+ exogenous_regressors = array_ops.identity(exogenous_regressors)
batch_size = array_ops.shape(predict_times)[0]
num_predict_values = array_ops.shape(predict_times)[1]
prediction_iterations = ((num_predict_values + self.output_window_size - 1)
// self.output_window_size)
- # Pad predict_times so as to have exact multiple of self.output_window_size
- # values per example.
+ # Pad predict_times and exogenous regressors so as to have exact multiple of
+ # self.output_window_size values per example.
padding_size = (prediction_iterations * self.output_window_size -
num_predict_values)
- padding = array_ops.zeros([batch_size, padding_size], predict_times.dtype)
- predict_times = control_flow_ops.cond(
- padding_size > 0, lambda: array_ops.concat([predict_times, padding], 1),
- lambda: predict_times)
+ predict_times = array_ops.pad(
+ predict_times, [[0, 0], [0, padding_size]])
+ exogenous_regressors = array_ops.pad(
+ exogenous_regressors, [[0, 0], [0, padding_size], [0, 0]])
state = features[PredictionFeatures.STATE_TUPLE]
- (state_times, state_values) = state
+ (state_times, state_values, state_exogenous_regressors) = state
state_times = math_ops.cast(
ops.convert_to_tensor(state_times), dtypes.int32)
state_values = ops.convert_to_tensor(state_values, dtype=self.dtype)
+ state_exogenous_regressors = ops.convert_to_tensor(
+ state_exogenous_regressors, dtype=self.dtype)
initial_input_times = predict_times[:, :self.output_window_size]
+ initial_input_exogenous_regressors = (
+ exogenous_regressors[:, :self.output_window_size, :])
if self.input_window_size > 0:
initial_input_times = array_ops.concat(
[state_times[:, -self.input_window_size:], initial_input_times], 1)
@@ -279,6 +331,11 @@ class ARModel(model.TimeSeriesModel):
check_ops.assert_equal(values_size, times_size)
]):
initial_input_values = state_values[:, -self.input_window_size:, :]
+ initial_input_exogenous_regressors = array_ops.concat(
+ [state_exogenous_regressors[:, -self.input_window_size:, :],
+ initial_input_exogenous_regressors[
+ :, :self.output_window_size, :]],
+ axis=1)
else:
initial_input_values = 0
@@ -288,9 +345,10 @@ class ARModel(model.TimeSeriesModel):
return math_ops.less(iteration_number, prediction_iterations)
def _while_body(iteration_number, input_times, input_values,
- mean_ta, covariance_ta):
+ input_exogenous_regressors, mean_ta, covariance_ta):
"""Predict self.output_window_size values."""
- prediction_ops = self.prediction_ops(input_times, input_values)
+ prediction_ops = self.prediction_ops(
+ input_times, input_values, input_exogenous_regressors)
predicted_mean = prediction_ops["mean"]
predicted_covariance = prediction_ops["covariance"]
offset = self.output_window_size * gen_math_ops.minimum(
@@ -299,20 +357,33 @@ class ARModel(model.TimeSeriesModel):
if self.output_window_size < self.input_window_size:
new_input_values = array_ops.concat(
[input_values[:, self.output_window_size:, :], predicted_mean], 1)
+ new_input_exogenous_regressors = array_ops.concat(
+ [input_exogenous_regressors[:, -self.input_window_size:, :],
+ exogenous_regressors[
+ :, offset:offset + self.output_window_size, :]],
+ axis=1)
new_input_times = array_ops.concat([
- input_times[:, self.output_window_size:],
+ input_times[:, -self.input_window_size:],
predict_times[:, offset:offset + self.output_window_size]
], 1)
else:
new_input_values = predicted_mean[:, -self.input_window_size:, :]
+ new_input_exogenous_regressors = exogenous_regressors[
+ :,
+ offset - self.input_window_size:offset + self.output_window_size,
+ :]
new_input_times = predict_times[
:,
offset - self.input_window_size:offset + self.output_window_size]
else:
new_input_values = input_values
+ new_input_exogenous_regressors = exogenous_regressors[
+ :, offset:offset + self.output_window_size, :]
new_input_times = predict_times[:,
offset:offset + self.output_window_size]
new_input_times.set_shape(initial_input_times.get_shape())
+ new_input_exogenous_regressors.set_shape(
+ initial_input_exogenous_regressors.get_shape())
new_mean_ta = mean_ta.write(iteration_number, predicted_mean)
if isinstance(covariance_ta, tensor_array_ops.TensorArray):
new_covariance_ta = covariance_ta.write(iteration_number,
@@ -322,6 +393,7 @@ class ARModel(model.TimeSeriesModel):
return (iteration_number + 1,
new_input_times,
new_input_values,
+ new_input_exogenous_regressors,
new_mean_ta,
new_covariance_ta)
@@ -332,9 +404,13 @@ class ARModel(model.TimeSeriesModel):
if self.loss != ARModel.SQUARED_LOSS else 0.)
mean_ta_init = tensor_array_ops.TensorArray(
dtype=self.dtype, size=prediction_iterations)
- _, _, _, mean_ta, covariance_ta = control_flow_ops.while_loop(
+ _, _, _, _, mean_ta, covariance_ta = control_flow_ops.while_loop(
_while_condition, _while_body, [
- 0, initial_input_times, initial_input_values, mean_ta_init,
+ 0,
+ initial_input_times,
+ initial_input_values,
+ initial_input_exogenous_regressors,
+ mean_ta_init,
covariance_ta_init
])
@@ -366,11 +442,11 @@ class ARModel(model.TimeSeriesModel):
return {"mean": predicted_mean,
"covariance": predicted_covariance}
- def _process_window(self, features, mode):
+ def _process_window(self, features, mode, exogenous_regressors):
"""Compute model outputs on a single window of data."""
- # TODO(agarwal): Use exogenous features
times = math_ops.cast(features[TrainEvalFeatures.TIMES], dtypes.int64)
values = math_ops.cast(features[TrainEvalFeatures.VALUES], dtype=self.dtype)
+ exogenous_regressors = math_ops.cast(exogenous_regressors, dtype=self.dtype)
original_values = values
# Extra shape checking for the window size (above that in
@@ -395,7 +471,8 @@ class ARModel(model.TimeSeriesModel):
input_values = values[:, :self.input_window_size, :]
else:
input_values = None
- prediction_ops = self.prediction_ops(times, input_values)
+ prediction_ops = self.prediction_ops(
+ times, input_values, exogenous_regressors)
prediction = prediction_ops["mean"]
covariance = prediction_ops["covariance"]
targets = array_ops.slice(values, [0, self.input_window_size, 0],
@@ -419,7 +496,8 @@ class ARModel(model.TimeSeriesModel):
return model.ModelOutputs(
loss=loss,
end_state=(times[:, -self.input_window_size:],
- values[:, -self.input_window_size:, :]),
+ values[:, -self.input_window_size:, :],
+ exogenous_regressors[:, -self.input_window_size:, :]),
predictions={"mean": prediction, "covariance": covariance,
"observed": original_values[:, -self.output_window_size:]},
prediction_times=times[:, -self.output_window_size:])
@@ -454,17 +532,24 @@ class ARModel(model.TimeSeriesModel):
"""
features = {feature_name: ops.convert_to_tensor(feature_value)
for feature_name, feature_value in features.items()}
+ times = features[TrainEvalFeatures.TIMES]
+ exogenous_regressors = self._process_exogenous_features(
+ times=times,
+ features={key: value for key, value in features.items()
+ if key not in [TrainEvalFeatures.TIMES,
+ TrainEvalFeatures.VALUES,
+ PredictionFeatures.STATE_TUPLE]})
if mode == estimator_lib.ModeKeys.TRAIN:
# For training, we require the window size to be self.window_size as
# iterating sequentially on larger windows could introduce a bias.
- return self._process_window(features, mode=mode)
+ return self._process_window(
+ features, mode=mode, exogenous_regressors=exogenous_regressors)
elif mode == estimator_lib.ModeKeys.EVAL:
# For evaluation, we allow the user to pass in a larger window, in which
# case we try to cover as much of the window as possible without
# overlap. Quantitative evaluation is more efficient/correct with fixed
# windows matching self.window_size (as with training), but this looping
# allows easy plotting of "in-sample" predictions.
- times = features[TrainEvalFeatures.TIMES]
times.get_shape().assert_has_rank(2)
static_window_size = times.get_shape()[1].value
if (static_window_size is not None
@@ -500,7 +585,9 @@ class ARModel(model.TimeSeriesModel):
feature_name:
feature_value[:, base_offset:base_offset + self.window_size]
for feature_name, feature_value in features.items()},
- mode=mode)
+ mode=mode,
+ exogenous_regressors=exogenous_regressors[
+ :, base_offset:base_offset + self.window_size])
# This code needs to be updated if new predictions are added in
# self._process_window
assert len(model_outputs.predictions) == 3
@@ -525,7 +612,9 @@ class ARModel(model.TimeSeriesModel):
batch_size = array_ops.shape(times)[0]
prediction_shape = [batch_size, self.output_window_size * num_iterations,
self.num_features]
- previous_state_times, previous_state_values = state
+ (previous_state_times,
+ previous_state_values,
+ previous_state_exogenous_regressors) = state
# Make sure returned state always has windows of self.input_window_size,
# even if we were passed fewer than self.input_window_size points this
# time.
@@ -540,14 +629,24 @@ class ARModel(model.TimeSeriesModel):
self._scale_data(values)], axis=1)[:, -self.input_window_size:, :]
new_state_values.set_shape((None, self.input_window_size,
self.num_features))
+ new_exogenous_regressors = array_ops.concat(
+ [previous_state_exogenous_regressors,
+ exogenous_regressors], axis=1)[:, -self.input_window_size:, :]
+ new_exogenous_regressors.set_shape(
+ (None,
+ self.input_window_size,
+ self.exogenous_size))
else:
# There is no state to keep, and the strided slices above do not handle
# input_window_size=0.
new_state_times = previous_state_times
new_state_values = previous_state_values
+ new_exogenous_regressors = previous_state_exogenous_regressors
return model.ModelOutputs(
loss=math_ops.reduce_mean(loss_ta.stack(), axis=0),
- end_state=(new_state_times, new_state_values),
+ end_state=(new_state_times,
+ new_state_values,
+ new_exogenous_regressors),
predictions={
"mean": array_ops.reshape(
array_ops.transpose(mean_ta.stack(), [1, 0, 2, 3]),
@@ -604,7 +703,8 @@ class AnomalyMixtureARModel(ARModel):
num_features,
anomaly_distribution=GAUSSIAN_ANOMALY,
num_time_buckets=10,
- hidden_layer_sizes=None):
+ hidden_layer_sizes=None,
+ exogenous_feature_columns=None):
assert (anomaly_prior_probability < 1.0 and
anomaly_prior_probability > 0.0)
self._anomaly_prior_probability = anomaly_prior_probability
@@ -619,7 +719,8 @@ class AnomalyMixtureARModel(ARModel):
input_window_size=input_window_size,
output_window_size=output_window_size,
loss=ARModel.NORMAL_LIKELIHOOD_LOSS,
- hidden_layer_sizes=hidden_layer_sizes)
+ hidden_layer_sizes=hidden_layer_sizes,
+ exogenous_feature_columns=exogenous_feature_columns)
def _create_anomaly_ops(self, times, values, prediction_ops_dict):
anomaly_log_param = variable_scope.get_variable(
@@ -631,9 +732,9 @@ class AnomalyMixtureARModel(ARModel):
# distribution.
prediction_ops_dict["anomaly_params"] = gen_math_ops.exp(anomaly_log_param)
- def prediction_ops(self, times, values):
+ def prediction_ops(self, times, values, exogenous_regressors):
prediction_ops_dict = super(AnomalyMixtureARModel, self).prediction_ops(
- times, values)
+ times, values, exogenous_regressors)
self._create_anomaly_ops(times, values, prediction_ops_dict)
return prediction_ops_dict
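A detail worth noting from _process_exogenous_features above: returning a
zero-width [batch, window, 0] tensor when no exogenous columns are configured
lets prediction_ops reshape and concatenate unconditionally. A plain-numpy
sketch (illustrative shapes) of why that is a no-op:

    import numpy as np

    batch, window, exogenous_size = 4, 3, 0
    regressors = np.zeros((batch, window, exogenous_size))
    # Flattening a zero-width window yields a [batch, 0] array...
    flattened = regressors.reshape((batch, window * exogenous_size))
    inp = np.ones((batch, 6))
    # ...so concatenating it onto the input activations changes nothing.
    activation = np.concatenate([inp, flattened], axis=1)
    assert activation.shape == (4, 6)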
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py
index 1e1ca4e77f..d078ac8d46 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py
@@ -155,12 +155,15 @@ class ARModelTest(test.TestCase):
state_times = np.expand_dims(train_data_times[:input_window_size], 0)
state_values = np.expand_dims(
train_data_values[:input_window_size, :], 0)
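+ # Adds a feature axis, then slices it to width zero: an exogenous history
+ # of shape [1, input_window_size, 0] for a model with no exogenous columns.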
+ state_exogenous = state_times[:, :, None][:, :, :0]
def prediction_input_fn():
return ({
PredictionFeatures.TIMES: training.limit_epochs(
predict_times, num_epochs=1),
- PredictionFeatures.STATE_TUPLE: (state_times, state_values)
+ PredictionFeatures.STATE_TUPLE: (state_times,
+ state_values,
+ state_exogenous)
}, {})
(predictions,) = tuple(estimator.predict(input_fn=prediction_input_fn))
predicted_mean = predictions["mean"][:, 0]
@@ -246,7 +249,8 @@ class ARModelTest(test.TestCase):
with session.Session():
predicted_values = model.predict({
PredictionFeatures.TIMES: [[4, 6, 10]],
- PredictionFeatures.STATE_TUPLE: ([[1, 2]], [[[1.], [2.]]])
+ PredictionFeatures.STATE_TUPLE: (
+ [[1, 2]], [[[1.], [2.]]], [[[], []]])
})
variables.global_variables_initializer().run()
self.assertAllEqual(predicted_values["mean"].eval().shape,
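As the test changes above show, the state tuple passed through
PredictionFeatures.STATE_TUPLE now has three elements. A sketch of
warm-starting predict by hand for a model configured without exogenous
columns (shapes mirror the test; the zero-width third element matches
exogenous_size == 0):

    import numpy as np
    from tensorflow.contrib.timeseries.python.timeseries import feature_keys

    input_window_size = 2
    state_times = np.array([[1, 2]], dtype=np.int64)           # [batch, window]
    state_values = np.array([[[1.], [2.]]], dtype=np.float32)  # [batch, window, 1]
    # Zero-width exogenous history when no exogenous columns are configured.
    state_exogenous = np.zeros((1, input_window_size, 0), dtype=np.float32)
    features = {
        feature_keys.PredictionFeatures.TIMES: [[4, 6, 10]],
        feature_keys.PredictionFeatures.STATE_TUPLE: (
            state_times, state_values, state_exogenous),
    }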
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index 886e1846e2..f4608ca2d1 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -190,7 +190,7 @@ class ARRegressor(TimeSeriesRegressor):
def __init__(
self, periodicities, input_window_size, output_window_size,
- num_features, num_time_buckets=10,
+ num_features, exogenous_feature_columns=None, num_time_buckets=10,
loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS, hidden_layer_sizes=None,
anomaly_prior_probability=None, anomaly_distribution=None,
optimizer=None, model_dir=None, config=None):
@@ -205,7 +205,12 @@ class ARRegressor(TimeSeriesRegressor):
output_window_size: Number of future time steps to predict. Note that
setting it to > 1 empirically seems to give a better fit.
num_features: The dimensionality of the time series (one for univariate,
- more than one for multivariate).
+ more than one for multivariate).
+ exogenous_feature_columns: A list of `tf.feature_column`s (for example
+ `tf.feature_column.embedding_column`) corresponding to exogenous
+ features which provide extra information to the model but are not part
+ of the series to be predicted. Passed to
+ `tf.feature_column.input_layer`.
num_time_buckets: Number of buckets into which to divide (time %
periodicity) for generating time based features.
loss: Loss function to use for training. Currently supported values are
@@ -241,6 +246,7 @@ class ARRegressor(TimeSeriesRegressor):
anomaly_distribution = ar_model.AnomalyMixtureARModel.GAUSSIAN_ANOMALY
model = ar_model.ARModel(
periodicities=periodicities, num_features=num_features,
+ exogenous_feature_columns=exogenous_feature_columns,
num_time_buckets=num_time_buckets,
input_window_size=input_window_size,
output_window_size=output_window_size, loss=loss,
@@ -255,6 +261,7 @@ class ARRegressor(TimeSeriesRegressor):
input_window_size=input_window_size,
output_window_size=output_window_size,
num_features=num_features,
+ exogenous_feature_columns=exogenous_feature_columns,
num_time_buckets=num_time_buckets,
hidden_layer_sizes=hidden_layer_sizes,
anomaly_prior_probability=anomaly_prior_probability,
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index 9f161c1695..eebee053f8 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.timeseries.python.timeseries import saved_model_utils
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator_lib
+from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
@@ -48,12 +49,17 @@ class TimeSeriesRegressorTest(test.TestCase):
def _fit_restore_fit_test_template(self, estimator_fn, dtype):
"""Tests restoring previously fit models."""
model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
- first_estimator = estimator_fn(model_dir)
+ exogenous_feature_columns = (
+ feature_column.numeric_column("exogenous"),
+ )
+ first_estimator = estimator_fn(model_dir, exogenous_feature_columns)
times = numpy.arange(20, dtype=numpy.int64)
values = numpy.arange(20, dtype=dtype.as_numpy_dtype)
+ exogenous = numpy.arange(20, dtype=dtype.as_numpy_dtype)
features = {
feature_keys.TrainEvalFeatures.TIMES: times,
- feature_keys.TrainEvalFeatures.VALUES: values
+ feature_keys.TrainEvalFeatures.VALUES: values,
+ "exogenous": exogenous
}
train_input_fn = input_pipeline.RandomWindowInputFn(
input_pipeline.NumpyReader(features), shuffle_seed=2, num_threads=1,
@@ -68,14 +74,19 @@ class TimeSeriesRegressorTest(test.TestCase):
first_loss_after_fit = first_estimator.evaluate(
input_fn=eval_input_fn, steps=1)["loss"]
self.assertLess(first_loss_after_fit, first_loss_before_fit)
- second_estimator = estimator_fn(model_dir)
+ second_estimator = estimator_fn(model_dir, exogenous_feature_columns)
second_estimator.train(input_fn=train_input_fn, steps=2)
whole_dataset_input_fn = input_pipeline.WholeDatasetInputFn(
input_pipeline.NumpyReader(features))
whole_dataset_evaluation = second_estimator.evaluate(
input_fn=whole_dataset_input_fn, steps=1)
+ exogenous_values_ten_steps = {
+ "exogenous": numpy.arange(
+ 10, dtype=dtype.as_numpy_dtype)[None, :, None]
+ }
predict_input_fn = input_pipeline.predict_continuation_input_fn(
evaluation=whole_dataset_evaluation,
+ exogenous_features=exogenous_values_ten_steps,
steps=10)
# Also tests that limit_epochs in predict_continuation_input_fn prevents
# infinite iteration
@@ -92,6 +103,7 @@ class TimeSeriesRegressorTest(test.TestCase):
saved_prediction = saved_model_utils.predict_continuation(
continue_from=whole_dataset_evaluation,
steps=10,
+ exogenous_features=exogenous_values_ten_steps,
signatures=signatures,
session=sess)
# Saved model predictions should be the same as Estimator predictions
@@ -104,7 +116,8 @@ class TimeSeriesRegressorTest(test.TestCase):
continue_from=whole_dataset_evaluation,
features={
feature_keys.FilteringFeatures.TIMES: times[None, -1] + 2,
- feature_keys.FilteringFeatures.VALUES: values[None, -1] + 2.
+ feature_keys.FilteringFeatures.VALUES: values[None, -1] + 2.,
+ "exogenous": values[None, -1, None] + 12.
},
signatures=signatures,
session=sess)
@@ -112,6 +125,10 @@ class TimeSeriesRegressorTest(test.TestCase):
second_saved_prediction = saved_model_utils.predict_continuation(
continue_from=first_filtering,
steps=1,
+ exogenous_features={
+ "exogenous": numpy.arange(
+ 1, dtype=dtype.as_numpy_dtype)[None, :, None]
+ },
signatures=signatures,
session=sess)
self.assertEqual(
@@ -122,7 +139,8 @@ class TimeSeriesRegressorTest(test.TestCase):
continue_from=first_filtering,
features={
feature_keys.FilteringFeatures.TIMES: times[-1] + 3,
- feature_keys.FilteringFeatures.VALUES: values[-1] + 3.
+ feature_keys.FilteringFeatures.VALUES: values[-1] + 3.,
+ "exogenous": values[-1, None] + 13.
},
signatures=signatures,
session=sess)
@@ -131,7 +149,8 @@ class TimeSeriesRegressorTest(test.TestCase):
six.assertCountEqual(
self,
[feature_keys.FilteringFeatures.TIMES,
- feature_keys.FilteringFeatures.VALUES],
+ feature_keys.FilteringFeatures.VALUES,
+ "exogenous"],
signatures.signature_def[
feature_keys.SavedModelLabels.COLD_START_FILTER].inputs.keys())
batch_numpy_times = numpy.tile(
@@ -142,7 +161,8 @@ class TimeSeriesRegressorTest(test.TestCase):
session=sess,
features={
feature_keys.FilteringFeatures.TIMES: batch_numpy_times,
- feature_keys.FilteringFeatures.VALUES: batch_numpy_values
+ feature_keys.FilteringFeatures.VALUES: batch_numpy_values,
+ "exogenous": 10. + batch_numpy_values
}
)
predict_times = numpy.tile(
@@ -150,26 +170,32 @@ class TimeSeriesRegressorTest(test.TestCase):
predictions = saved_model_utils.predict_continuation(
continue_from=state,
times=predict_times,
+ exogenous_features={
+ "exogenous": numpy.tile(numpy.arange(
+ 15, dtype=dtype.as_numpy_dtype), (10,))[None, :, None]
+ },
signatures=signatures,
session=sess)
self.assertAllEqual([10, 15, 1], predictions["mean"].shape)
def test_fit_restore_fit_ar_regressor(self):
- def _estimator_fn(model_dir):
+ def _estimator_fn(model_dir, exogenous_feature_columns):
return estimators.ARRegressor(
periodicities=10, input_window_size=10, output_window_size=6,
num_features=1, model_dir=model_dir, config=_SeedRunConfig(),
# This test is flaky with normal likelihood loss (could add more
# training iterations instead).
- loss=ar_model.ARModel.SQUARED_LOSS)
+ loss=ar_model.ARModel.SQUARED_LOSS,
+ exogenous_feature_columns=exogenous_feature_columns)
self._fit_restore_fit_test_template(_estimator_fn, dtype=dtypes.float32)
def test_fit_restore_fit_structural_ensemble_regressor(self):
dtype = dtypes.float32
- def _estimator_fn(model_dir):
+ def _estimator_fn(model_dir, exogenous_feature_columns):
return estimators.StructuralEnsembleRegressor(
num_features=1, periodicities=10, model_dir=model_dir, dtype=dtype,
- config=_SeedRunConfig())
+ config=_SeedRunConfig(),
+ exogenous_feature_columns=exogenous_feature_columns)
self._fit_restore_fit_test_template(_estimator_fn, dtype=dtype)
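Finally, the shape convention these tests exercise: at prediction time each
exogenous feature is prefixed by the shape of TIMES ([batch size, window
size]) and must cover the whole window. A small check with illustrative sizes
matching the batched prediction above:

    import numpy as np

    batch_size, predict_window = 10, 15
    predict_times = np.tile(
        np.arange(30, 30 + predict_window, dtype=np.int64)[None, :],
        (batch_size, 1))                  # [10, 15]
    exogenous = np.tile(
        np.arange(predict_window, dtype=np.float32)[None, :, None],
        (batch_size, 1, 1))               # [10, 15, 1]
    assert exogenous.shape[:2] == predict_times.shape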