path: root/tensorflow/contrib/timeseries
author    Martin Wicke <577277+martinwicke@users.noreply.github.com>  2018-09-22 09:45:11 -0700
committer GitHub <noreply@github.com>  2018-09-22 09:45:11 -0700
commit    413ac36f33deb0c354dd687963d2410eab048970 (patch)
tree      fd4dc4e9fc5a76efd62c78c213b0e34983359256 /tensorflow/contrib/timeseries
parent    c22d996c3d6a16db292bd3464b2ef7b91adae676 (diff)
parent    e692dda4c8b199555e2fa32132a7784e0893c870 (diff)
Merge branch 'master' into fix_expand_dims
Diffstat (limited to 'tensorflow/contrib/timeseries')
-rw-r--r--  tensorflow/contrib/timeseries/__init__.py                                                          |   3
-rw-r--r--  tensorflow/contrib/timeseries/examples/BUILD                                                       |  34
-rw-r--r--  tensorflow/contrib/timeseries/examples/known_anomaly.py                                            |  75
-rw-r--r--  tensorflow/contrib/timeseries/examples/known_anomaly_test.py                                       |  18
-rw-r--r--  tensorflow/contrib/timeseries/examples/multivariate.py                                             |   4
-rw-r--r--  tensorflow/contrib/timeseries/examples/predict.py                                                  |  16
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/BUILD                                              |   8
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/__init__.py                                        |   1
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/ar_model.py                                        | 482
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py                                   |  96
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/estimators.py                                      | 366
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/estimators_test.py                                 | 124
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/head.py                                            | 119
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/head_test.py                                       | 163
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py                             |   6
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/math_utils.py                                      |   6
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py                                 |  25
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py                                |   2
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/state_management_test.py                           |   8
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD                           |   3
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py |   2
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py                |   6
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py           |  22
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py       |  26
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py                   |   6
25 files changed, 1288 insertions(+), 333 deletions(-)
diff --git a/tensorflow/contrib/timeseries/__init__.py b/tensorflow/contrib/timeseries/__init__.py
index 11db56b1b7..654a4db098 100644
--- a/tensorflow/contrib/timeseries/__init__.py
+++ b/tensorflow/contrib/timeseries/__init__.py
@@ -27,6 +27,9 @@
@@TrainEvalFeatures
@@FilteringResults
+
+@@TimeSeriesRegressor
+@@OneShotPredictionHead
"""
from __future__ import absolute_import
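With the re-export above, `TimeSeriesRegressor` and `OneShotPredictionHead` become reachable from the package root. A minimal sketch (assuming a TF 1.x build with `contrib` available):

```python
import tensorflow as tf

# Sketch: after this change both symbols resolve from the package root,
# alongside the existing @@-documented exports.
regressor_cls = tf.contrib.timeseries.TimeSeriesRegressor
head_cls = tf.contrib.timeseries.OneShotPredictionHead
print(regressor_cls.__name__, head_cls.__name__)
```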
diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD
index 32e948a009..21c0c30c19 100644
--- a/tensorflow/contrib/timeseries/examples/BUILD
+++ b/tensorflow/contrib/timeseries/examples/BUILD
@@ -8,14 +8,23 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
+config_setting(
+ name = "empty_condition",
+ values = {"define": "UNUSED=unused"},
+)
+
py_binary(
name = "predict",
srcs = ["predict.py"],
+ data = ["data/period_trend.csv"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
- deps = [
- "//tensorflow:tensorflow_py",
+ deps = select({
+ ":empty_condition": [],
+ "//conditions:default": [],
+ }) + [
"//third_party/py/numpy",
+ "//tensorflow:tensorflow_py",
],
)
@@ -41,9 +50,12 @@ py_binary(
data = ["data/changepoints.csv"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
- deps = [
- "//tensorflow:tensorflow_py",
+ deps = select({
+ ":empty_condition": [],
+ "//conditions:default": [],
+ }) + [
"//third_party/py/numpy",
+ "//tensorflow:tensorflow_py",
],
)
@@ -64,9 +76,12 @@ py_binary(
data = ["data/multivariate_level.csv"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
- deps = [
- "//tensorflow:tensorflow_py",
+ deps = select({
+ ":empty_condition": [],
+ "//conditions:default": [],
+ }) + [
"//third_party/py/numpy",
+ "//tensorflow:tensorflow_py",
],
)
@@ -89,11 +104,14 @@ py_binary(
data = ["data/multivariate_periods.csv"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
- deps = [
+ deps = select({
+ ":empty_condition": [],
+ "//conditions:default": [],
+ }) + [
+ "//third_party/py/numpy",
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/timeseries/python/timeseries:estimators",
"//tensorflow/contrib/timeseries/python/timeseries:model",
- "//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly.py b/tensorflow/contrib/timeseries/examples/known_anomaly.py
index e77628ddd3..1226433625 100644
--- a/tensorflow/contrib/timeseries/examples/known_anomaly.py
+++ b/tensorflow/contrib/timeseries/examples/known_anomaly.py
@@ -41,17 +41,8 @@ _MODULE_PATH = path.dirname(__file__)
_DATA_FILE = path.join(_MODULE_PATH, "data/changepoints.csv")
-def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300):
- """Training, evaluating, and predicting on a series with changepoints."""
-
- # Indicate the format of our exogenous feature, in this case a string
- # representing a boolean value.
- string_feature = tf.feature_column.categorical_column_with_vocabulary_list(
- key="is_changepoint", vocabulary_list=["no", "yes"])
- # Specify the way this feature is presented to the model, here using a one-hot
- # encoding.
- one_hot_feature = tf.feature_column.indicator_column(
- categorical_column=string_feature)
+def state_space_estimator(exogenous_feature_columns):
+ """Constructs a StructuralEnsembleRegressor."""
def _exogenous_update_condition(times, features):
del times # unused
@@ -62,14 +53,48 @@ def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300):
# no changepoint.
return tf.equal(tf.squeeze(features["is_changepoint"], axis=-1), "yes")
- estimator = tf.contrib.timeseries.StructuralEnsembleRegressor(
- periodicities=12,
- # Extract a smooth period by constraining the number of latent values
- # being cycled between.
- cycle_num_latent_values=3,
- num_features=1,
- exogenous_feature_columns=[one_hot_feature],
- exogenous_update_condition=_exogenous_update_condition)
+ return (
+ tf.contrib.timeseries.StructuralEnsembleRegressor(
+ periodicities=12,
+ # Extract a smooth period by constraining the number of latent values
+ # being cycled between.
+ cycle_num_latent_values=3,
+ num_features=1,
+ exogenous_feature_columns=exogenous_feature_columns,
+ exogenous_update_condition=_exogenous_update_condition),
+ # Use truncated backpropagation with a window size of 64, batching
+ # together 4 of these windows (random offsets) per training step. Training
+ # with exogenous features often requires somewhat larger windows.
+ 4, 64)
+
+
+def autoregressive_estimator(exogenous_feature_columns):
+ input_window_size = 8
+ output_window_size = 2
+ return (
+ tf.contrib.timeseries.ARRegressor(
+ periodicities=12,
+ num_features=1,
+ input_window_size=input_window_size,
+ output_window_size=output_window_size,
+ exogenous_feature_columns=exogenous_feature_columns),
+ 64, input_window_size + output_window_size)
+
+
+def train_and_evaluate_exogenous(
+ estimator_fn, csv_file_name=_DATA_FILE, train_steps=300):
+ """Training, evaluating, and predicting on a series with changepoints."""
+ # Indicate the format of our exogenous feature, in this case a string
+ # representing a boolean value.
+ string_feature = tf.feature_column.categorical_column_with_vocabulary_list(
+ key="is_changepoint", vocabulary_list=["no", "yes"])
+ # Specify the way this feature is presented to the model, here using a one-hot
+ # encoding.
+ one_hot_feature = tf.feature_column.indicator_column(
+ categorical_column=string_feature)
+
+ estimator, batch_size, window_size = estimator_fn(
+ exogenous_feature_columns=[one_hot_feature])
reader = tf.contrib.timeseries.CSVReader(
csv_file_name,
# Indicate the format of our CSV file. First we have two standard columns,
@@ -85,10 +110,7 @@ def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300):
# This CSV has a header line; here we just ignore it.
skip_header_lines=1)
train_input_fn = tf.contrib.timeseries.RandomWindowInputFn(
- # Use truncated backpropagation with a window size of 64, batching
- # together 4 of these windows (random offsets) per training step. Training
- # with exogenous features often requires somewhat larger windows.
- reader, batch_size=4, window_size=64)
+ reader, batch_size=batch_size, window_size=window_size)
estimator.train(input_fn=train_input_fn, steps=train_steps)
evaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader)
evaluation = estimator.evaluate(input_fn=evaluation_input_fn, steps=1)
@@ -145,7 +167,12 @@ def main(unused_argv):
if not HAS_MATPLOTLIB:
raise ImportError(
"Please install matplotlib to generate a plot from this example.")
- make_plot("Ignoring a known anomaly", *train_and_evaluate_exogenous())
+ make_plot("Ignoring a known anomaly (state space)",
+ *train_and_evaluate_exogenous(
+ estimator_fn=state_space_estimator))
+ make_plot("Ignoring a known anomaly (autoregressive)",
+ *train_and_evaluate_exogenous(
+ estimator_fn=autoregressive_estimator, train_steps=3000))
pyplot.show()
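With the refactor above, each estimator factory returns an `(estimator, batch_size, window_size)` triple that `train_and_evaluate_exogenous` consumes. A minimal usage sketch mirroring the new `main`:

```python
from tensorflow.contrib.timeseries.examples import known_anomaly

# Mirrors the refactored main() above: pick a factory and let
# train_and_evaluate_exogenous unpack (estimator, batch_size, window_size).
(times, observed, all_times, mean, upper_limit, lower_limit,
 anomaly_locations) = known_anomaly.train_and_evaluate_exogenous(
     estimator_fn=known_anomaly.autoregressive_estimator, train_steps=3000)
```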
diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly_test.py b/tensorflow/contrib/timeseries/examples/known_anomaly_test.py
index c3e307cad8..57ccf8f260 100644
--- a/tensorflow/contrib/timeseries/examples/known_anomaly_test.py
+++ b/tensorflow/contrib/timeseries/examples/known_anomaly_test.py
@@ -23,12 +23,24 @@ from tensorflow.contrib.timeseries.examples import known_anomaly
from tensorflow.python.platform import test
-class KnownAnaomalyExampleTest(test.TestCase):
+class KnownAnomalyExampleTest(test.TestCase):
- def test_shapes_and_variance_structural(self):
+ def test_shapes_and_variance_structural_ar(self):
(times, observed, all_times, mean, upper_limit, lower_limit,
anomaly_locations) = known_anomaly.train_and_evaluate_exogenous(
- train_steps=50)
+ train_steps=1, estimator_fn=known_anomaly.autoregressive_estimator)
+ self.assertAllEqual(
+ anomaly_locations,
+ [25, 50, 75, 100, 125, 150, 175, 249])
+ self.assertAllEqual(all_times.shape, mean.shape)
+ self.assertAllEqual(all_times.shape, upper_limit.shape)
+ self.assertAllEqual(all_times.shape, lower_limit.shape)
+ self.assertAllEqual(times.shape, observed.shape)
+
+ def test_shapes_and_variance_structural_ssm(self):
+ (times, observed, all_times, mean, upper_limit, lower_limit,
+ anomaly_locations) = known_anomaly.train_and_evaluate_exogenous(
+ train_steps=50, estimator_fn=known_anomaly.state_space_estimator)
self.assertAllEqual(
anomaly_locations,
[25, 50, 75, 100, 125, 150, 175, 249])
diff --git a/tensorflow/contrib/timeseries/examples/multivariate.py b/tensorflow/contrib/timeseries/examples/multivariate.py
index ed799542fd..e81cb18ad7 100644
--- a/tensorflow/contrib/timeseries/examples/multivariate.py
+++ b/tensorflow/contrib/timeseries/examples/multivariate.py
@@ -80,8 +80,8 @@ def multivariate_train_and_sample(
session=session, steps=1))
next_sample = numpy.random.multivariate_normal(
# Squeeze out the batch and series length dimensions (both 1).
- mean=numpy.squeeze(current_prediction["mean"], axis=[0, 1]),
- cov=numpy.squeeze(current_prediction["covariance"], axis=[0, 1]))
+ mean=numpy.squeeze(current_prediction["mean"], axis=(0, 1)),
+ cov=numpy.squeeze(current_prediction["covariance"], axis=(0, 1)))
# Update model state so that future predictions are conditional on the
# value we just sampled.
filtering_features = {
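The `axis=[0, 1]` to `axis=(0, 1)` change matters because NumPy documents the `axis` argument of `squeeze` as `None`, an int, or a tuple of ints; a list is outside that contract and can fail depending on the NumPy version. A small sketch:

```python
import numpy as np

x = np.zeros((1, 1, 3))
# A tuple of ints is the documented, portable spelling for multiple axes,
# which is what the change above adopts.
y = np.squeeze(x, axis=(0, 1))
print(y.shape)  # (3,)
```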
diff --git a/tensorflow/contrib/timeseries/examples/predict.py b/tensorflow/contrib/timeseries/examples/predict.py
index 8147d40caa..b036911314 100644
--- a/tensorflow/contrib/timeseries/examples/predict.py
+++ b/tensorflow/contrib/timeseries/examples/predict.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import argparse
+import os
import sys
import numpy as np
@@ -40,6 +41,10 @@ except ImportError:
FLAGS = None
+_MODULE_PATH = os.path.dirname(__file__)
+_DEFAULT_DATA_FILE = os.path.join(_MODULE_PATH, "data/period_trend.csv")
+
+
def structural_ensemble_train_and_predict(csv_file_name):
# Cycle between 5 latent values over a period of 100. This leads to a very
# smooth periodic component (and a small model), which is a good fit for our
@@ -115,9 +120,12 @@ def main(unused_argv):
if not HAS_MATPLOTLIB:
raise ImportError(
"Please install matplotlib to generate a plot from this example.")
+ input_filename = FLAGS.input_filename
+ if input_filename is None:
+ input_filename = _DEFAULT_DATA_FILE
make_plot("Structural ensemble",
- *structural_ensemble_train_and_predict(FLAGS.input_filename))
- make_plot("AR", *ar_train_and_predict(FLAGS.input_filename))
+ *structural_ensemble_train_and_predict(input_filename))
+ make_plot("AR", *ar_train_and_predict(input_filename))
pyplot.show()
@@ -126,7 +134,7 @@ if __name__ == "__main__":
parser.add_argument(
"--input_filename",
type=str,
- required=True,
- help="Input csv file.")
+ required=False,
+ help="Input csv file (omit to use the data/period_trend.csv).")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
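A small self-contained sketch of the optional-flag fallback introduced above (a hypothetical standalone snippet, not part of the diff):

```python
import argparse

# When --input_filename is omitted, main() substitutes the bundled CSV path.
parser = argparse.ArgumentParser()
parser.add_argument("--input_filename", type=str, required=False,
                    help="Input csv file (omit to use the bundled "
                         "data/period_trend.csv).")
args = parser.parse_args([])  # simulate running with no flags
input_filename = args.input_filename or "data/period_trend.csv"
print(input_filename)  # data/period_trend.csv
```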
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index d2746032a0..c230919168 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -94,7 +94,6 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/estimator:estimator_py",
- "//tensorflow/python/estimator:export",
"//tensorflow/python/feature_column",
],
)
@@ -110,6 +109,7 @@ py_test(
"no_pip_gpu", # b/63391119
"nomsan", # Takes too long to run.
"notsan", # b/67865658
+ "optonly", # Takes too long to run without optimization.
],
deps = [
":ar_model",
@@ -148,17 +148,16 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/estimator:estimator_py",
- "//tensorflow/python/estimator:export",
- "//tensorflow/python/estimator:head",
- "//tensorflow/python/estimator:metric_keys",
],
)
py_test(
name = "head_test",
+ size = "large",
srcs = [
"head_test.py",
],
+ shard_count = 4,
srcs_version = "PY2AND3",
tags = ["no_pip_gpu"], # b/63391119
deps = [
@@ -183,6 +182,7 @@ py_test(
"//tensorflow/python/saved_model:loader",
"//tensorflow/python/saved_model:tag_constants",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/__init__.py b/tensorflow/contrib/timeseries/python/timeseries/__init__.py
index c683dad71d..8462138339 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/__init__.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/__init__.py
@@ -24,5 +24,6 @@ from tensorflow.contrib.timeseries.python.timeseries import saved_model_utils
from tensorflow.contrib.timeseries.python.timeseries.ar_model import *
from tensorflow.contrib.timeseries.python.timeseries.estimators import *
from tensorflow.contrib.timeseries.python.timeseries.feature_keys import *
+from tensorflow.contrib.timeseries.python.timeseries.head import *
from tensorflow.contrib.timeseries.python.timeseries.input_pipeline import *
# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
index 4f6527a546..9bbe87e301 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.contrib import distributions
+from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.contrib.timeseries.python.timeseries import model
from tensorflow.contrib.timeseries.python.timeseries import model_utils
from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures
@@ -29,6 +30,9 @@ from tensorflow.python.estimator import estimator_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.keras.engine import sequential
+from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.layers import core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@@ -40,15 +44,190 @@ from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
+class FlatPredictionModel(training.Model):
+ """Flattens input and output windows and puts them through dense layers.
+
+ This model does not operate on its own, but rather is a plugin to
+ `ARModel`. See `ARModel`'s constructor documentation
+ (`prediction_model_factory`) for a usage example.
+ """
+
+ def __init__(self,
+ num_features,
+ input_window_size,
+ output_window_size,
+ hidden_layer_sizes=None):
+ """Construct the flat prediction model.
+
+ Args:
+ num_features: number of input features per time step.
+ input_window_size: Number of past time steps of data to look at when doing
+ the regression.
+ output_window_size: Number of future time steps to predict. Note that
+ setting it to > 1 empirically seems to give a better fit.
+ hidden_layer_sizes: list of sizes of hidden layers.
+ """
+ super(FlatPredictionModel, self).__init__()
+ self._input_flatten = core.Flatten()
+ self._output_flatten = core.Flatten()
+ if hidden_layer_sizes:
+ self._hidden_layers = sequential.Sequential([
+ core.Dense(layer_size, activation=nn_ops.relu)
+ for layer_size in hidden_layer_sizes])
+ else:
+ self._hidden_layers = None
+ self._mean_transform = core.Dense(num_features * output_window_size,
+ name="predicted_mean")
+ self._covariance_transform = core.Dense(num_features * output_window_size,
+ name="log_sigma_square")
+ self._prediction_shape = [-1, output_window_size, num_features]
+
+ def call(self, input_window_features, output_window_features):
+ """Compute predictions from input and output windows.
+
+ Args:
+ input_window_features: A floating point Tensor with shape [batch size,
+ input window size, input features]. The batch dimension may not have
+ static shape information, but the window size and number of input
+ features are known at graph construction time and recorded in the static
+ shape information for the `input_window_features` `Tensor`. Note that
+ `input_window_size` may be zero.
+ output_window_features: A floating point Tensor with shape [batch size,
+ output window size, output features]. As with `input_window_features`,
+ the last two dimensions have static shape information. If there are no
+ output features, the size of the last dimension will be zero.
+ Returns:
+ A dictionary of predictions with keys "mean" and "covariance" (only
+ diagonal covariances are currently supported). Each has shape
+ [batch size, output window size, num_features], where num_features is the
+ same as the constructor argument.
+ """
+ if input_window_features.shape[1].value == 0:
+ # TODO(allenl): Make reshape()'s static shape information work on
+ # zero-size Tensors? Currently this special case is required because
+ # otherwise the Dense layers get unknown last dimensions.
+ activation = self._output_flatten(output_window_features)
+ elif output_window_features.shape[2].value == 0:
+ activation = self._input_flatten(input_window_features)
+ else:
+ activation = array_ops.concat(
+ [self._input_flatten(input_window_features),
+ self._output_flatten(output_window_features)],
+ axis=1)
+ if self._hidden_layers:
+ activation = self._hidden_layers(activation)
+ predicted_mean = array_ops.reshape(
+ self._mean_transform(activation),
+ self._prediction_shape)
+ predicted_covariance = array_ops.reshape(
+ gen_math_ops.exp(self._covariance_transform(activation)),
+ self._prediction_shape)
+ return {"mean": predicted_mean,
+ "covariance": predicted_covariance}
+
+
+class LSTMPredictionModel(training.Model):
+ """A simple encoder/decoder model using an LSTM.
+
+ This model does not operate on its own, but rather is a plugin to
+ `ARModel`. See `ARModel`'s constructor documentation
+ (`prediction_model_factory`) for a usage example.
+ """
+
+ def __init__(self,
+ num_features,
+ input_window_size,
+ output_window_size,
+ num_units=128):
+ """Construct the LSTM prediction model.
+
+ Args:
+ num_features: number of input features per time step.
+ input_window_size: Number of past time steps of data to look at when doing
+ the regression.
+ output_window_size: Number of future time steps to predict. Note that
+ setting it to > 1 empirically seems to give a better fit.
+ num_units: The number of units in the encoder and decoder LSTM cells.
+ """
+ super(LSTMPredictionModel, self).__init__()
+ self._encoder = lstm_ops.LSTMBlockFusedCell(
+ num_units=num_units, name="encoder")
+ self._decoder = lstm_ops.LSTMBlockFusedCell(
+ num_units=num_units, name="decoder")
+ self._mean_transform = core.Dense(num_features,
+ name="mean_transform")
+ self._covariance_transform = core.Dense(num_features,
+ name="covariance_transform")
+
+ def call(self, input_window_features, output_window_features):
+ """Compute predictions from input and output windows."""
+ # Convert to time major
+ input_window_features = array_ops.transpose(input_window_features,
+ [1, 0, 2])
+ output_window_features = array_ops.transpose(output_window_features,
+ [1, 0, 2])
+ _, encoder_state = self._encoder(
+ input_window_features, dtype=self.dtype)
+ decoder_output, _ = self._decoder(
+ output_window_features, dtype=self.dtype,
+ initial_state=encoder_state)
+
+ # Switch back to batch major
+ decoder_output = array_ops.transpose(decoder_output, [1, 0, 2])
+ predicted_mean = self._mean_transform(decoder_output)
+ predicted_covariance = gen_math_ops.exp(
+ self._covariance_transform(decoder_output))
+ return {"mean": predicted_mean,
+ "covariance": predicted_covariance}
+
+
class ARModel(model.TimeSeriesModel):
"""Auto-regressive model, both linear and non-linear.
Features to the model include time and values of input_window_size timesteps,
- and times for output_window_size timesteps. These are passed through zero or
- more hidden layers, and then fed to a loss function (e.g. squared loss).
+ and times for output_window_size timesteps. These are passed through a
+ configurable prediction model, and then fed to a loss function (e.g. squared
+ loss).
Note that this class can also be used to regress against time only by setting
the input_window_size to zero.
+
+ Each periodicity in the `periodicities` arg is divided by the
+ `num_time_buckets` into time buckets that are represented as features added
+ to the model.
+
+ A good heuristic for picking an appropriate periodicity for a given data set
+ would be the length of cycles in the data. For example, energy usage in a
+ home is typically cyclic each day. If the time feature in a home energy
+ usage dataset is in the unit of hours, then 24 would be an appropriate
+ periodicity. Similarly, a good heuristic for `num_time_buckets` is how often
+ the data is expected to change within the cycle. For the aforementioned home
+ energy usage dataset and periodicity of 24, then 48 would be a reasonable
+ value if usage is expected to change every half hour.
+
+ Each feature's value for a given example with time t is the difference
+ between t and the start of the time bucket it falls under. If it doesn't fall
+ under a feature's associated time bucket, then that feature's value is zero.
+
+ For example: if `periodicities` = (9, 12) and `num_time_buckets` = 3, then 6
+ features would be added to the model, 3 for periodicity 9 and 3 for
+ periodicity 12.
+
+ For an example data point where t = 17:
+ - It's in the 3rd time bucket for periodicity 9 (2nd period is 9-18 and 3rd
+ time bucket is 15-18)
+ - It's in the 2nd time bucket for periodicity 12 (2nd period is 12-24 and
+ 2nd time bucket is between 16-20).
+
+ Therefore the 6 added features for this row with t = 17 would be:
+
+ # Feature name (periodicity#_timebucket#), feature value
+ P9_T1, 0 # not in first time bucket
+ P9_T2, 0 # not in second time bucket
+ P9_T3, 2 # 17 - 15 since 15 is the start of the 3rd time bucket
+ P12_T1, 0 # not in first time bucket
+ P12_T2, 1 # 17 - 16 since 16 is the start of the 2nd time bucket
+ P12_T3, 0 # not in third time bucket
"""
SQUARED_LOSS = "squared_loss"
NORMAL_LIKELIHOOD_LOSS = "normal_likelihood_loss"
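The time-bucket feature scheme described in the docstring above can be checked with a short standalone sketch; `time_bucket_features` is a hypothetical helper written for illustration (the real computation lives in `_compute_time_features`). Evaluating it at t = 17 reproduces the six feature values listed in the docstring:

```python
def time_bucket_features(t, periodicities=(9, 12), num_time_buckets=3):
  features = {}
  for p in periodicities:
    bucket_size = p / num_time_buckets
    phase = t % p                       # position within the current period
    bucket = int(phase // bucket_size)  # which time bucket t falls into
    for b in range(num_time_buckets):
      # Offset from the bucket's start if t falls in it, zero otherwise.
      features["P%d_T%d" % (p, b + 1)] = (
          phase - b * bucket_size if b == bucket else 0)
  return features

print(time_bucket_features(17))
# {'P9_T1': 0, 'P9_T2': 0, 'P9_T3': 2.0, 'P12_T1': 0, 'P12_T2': 1.0, 'P12_T3': 0}
```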
@@ -58,39 +237,61 @@ class ARModel(model.TimeSeriesModel):
input_window_size,
output_window_size,
num_features,
+ prediction_model_factory=FlatPredictionModel,
num_time_buckets=10,
loss=NORMAL_LIKELIHOOD_LOSS,
- hidden_layer_sizes=None):
+ exogenous_feature_columns=None):
"""Constructs an auto-regressive model.
Args:
periodicities: periodicities of the input data, in the same units as the
- time feature. Note this can be a single value or a list of values for
+ time feature (for example 24 if feeding hourly data with a daily
+ periodicity, or 60 * 24 if feeding minute-level data with daily
+ periodicity). Note this can be a single value or a list of values for
multiple periodicities.
input_window_size: Number of past time steps of data to look at when doing
the regression.
output_window_size: Number of future time steps to predict. Note that
setting it to > 1 empirically seems to give a better fit.
num_features: number of input features per time step.
+ prediction_model_factory: A callable taking arguments `num_features`,
+ `input_window_size`, and `output_window_size` and returning a
+ `tf.keras.Model`. The `Model`'s `call()` takes two arguments: an input
+ window and an output window, and returns a dictionary of predictions.
+ See `FlatPredictionModel` for an example. Example usage:
+
+        ```python
+        model = ar_model.ARModel(
+            periodicities=2, num_features=3,
+            prediction_model_factory=functools.partial(
+                FlatPredictionModel,
+                hidden_layer_sizes=[10, 10]))
+        ```
+
+ The default model computes predictions as a linear function of flattened
+ input and output windows.
num_time_buckets: Number of buckets into which to divide (time %
- periodicity) for generating time based features.
+ periodicity). This value multiplied by the number of periodicities is
+ the number of time features added to the model.
loss: Loss function to use for training. Currently supported values are
SQUARED_LOSS and NORMAL_LIKELIHOOD_LOSS. Note that for
NORMAL_LIKELIHOOD_LOSS, we train the covariance term as well. For
SQUARED_LOSS, the evaluation loss is reported based on un-scaled
observations and predictions, while the training loss is computed on
normalized data (if input statistics are available).
- hidden_layer_sizes: list of sizes of hidden layers.
+ exogenous_feature_columns: A list of `tf.feature_column`s (for example
+ `tf.feature_column.embedding_column`) corresponding to
+ features which provide extra information to the model but are not part
+ of the series to be predicted.
"""
+ self._model_factory = prediction_model_factory
self.input_window_size = input_window_size
self.output_window_size = output_window_size
- if hidden_layer_sizes is None:
- hidden_layer_sizes = []
- self.hidden_layer_sizes = hidden_layer_sizes
self.window_size = self.input_window_size + self.output_window_size
self.loss = loss
super(ARModel, self).__init__(
- num_features=num_features)
+ num_features=num_features,
+ exogenous_feature_columns=exogenous_feature_columns)
+ if exogenous_feature_columns is not None:
+ self.exogenous_size = self._get_exogenous_embedding_shape()[-1]
+ else:
+ self.exogenous_size = 0
assert num_time_buckets > 0
self._buckets = int(num_time_buckets)
if periodicities is None or not periodicities:
@@ -98,19 +299,35 @@ class ARModel(model.TimeSeriesModel):
elif (not isinstance(periodicities, list) and
not isinstance(periodicities, tuple)):
periodicities = [periodicities]
- self._periods = [int(p) for p in periodicities]
- for p in self._periods:
+ self._periodicities = [int(p) for p in periodicities]
+ for p in self._periodicities:
assert p > 0
- assert len(self._periods) or self.input_window_size
+ assert len(self._periodicities) or self.input_window_size
assert output_window_size > 0
+ def initialize_graph(self, input_statistics=None):
+ super(ARModel, self).initialize_graph(input_statistics=input_statistics)
+ self._model_scope = variable_scope.variable_scope(
+ # The trailing slash means we strip all enclosing variable_scopes, which
+ # unfortunately is necessary because the model gets called inside and
+ # outside a "while" scope (for prediction and training respectively),
+ # and the variables names need to match.
+ "model/", use_resource=True)
+ self._model_instance = self._model_factory(
+ num_features=self.num_features,
+ input_window_size=self.input_window_size,
+ output_window_size=self.output_window_size)
+
def get_start_state(self):
# State which matches the format we'll return later. Typically this will not
# be used by the model directly, but the shapes and dtypes should match so
# that the serving input_receiver_fn gets placeholder shapes correct.
return (array_ops.zeros([self.input_window_size], dtype=dtypes.int64),
array_ops.zeros(
- [self.input_window_size, self.num_features], dtype=self.dtype))
+ [self.input_window_size, self.num_features], dtype=self.dtype),
+ array_ops.zeros(
+ [self.input_window_size, self.exogenous_size],
+ dtype=self.dtype))
# TODO(allenl,agarwal): Support sampling for AR.
def random_model_parameters(self, seed=None):
@@ -152,18 +369,7 @@ class ARModel(model.TimeSeriesModel):
return array_ops.reshape(predicted_mean,
[-1, self.output_window_size, self.num_features])
- def _create_hidden_stack(self, activation, activation_size):
- activations = []
- for layer_number, layer_size in enumerate(self.hidden_layer_sizes):
- # TODO(agarwal): Migrate to fully_connected in tf slim
- activation = model_utils.fully_connected(
- activation, activation_size, layer_size,
- name="layer_{}".format(layer_number))
- activation_size = layer_size
- activations.append((activation, activation_size))
- return activations
-
- def prediction_ops(self, times, values):
+ def prediction_ops(self, times, values, exogenous_regressors):
"""Compute model predictions given input data.
Args:
@@ -173,45 +379,82 @@ class ARModel(model.TimeSeriesModel):
prediction times.
values: A [batch size, self.input_window_size, self.num_features] Tensor
with input features.
+ exogenous_regressors: A [batch size, self.window_size,
+ self.exogenous_size] Tensor with exogenous features.
Returns:
Tuple (predicted_mean, predicted_covariance), where each element is a
Tensor with shape [batch size, self.output_window_size,
self.num_features].
"""
times.get_shape().assert_is_compatible_with([None, self.window_size])
- activations = []
+ batch_size = array_ops.shape(times)[0]
if self.input_window_size:
values.get_shape().assert_is_compatible_with(
[None, self.input_window_size, self.num_features])
+ if exogenous_regressors is not None:
+ exogenous_regressors.get_shape().assert_is_compatible_with(
+ [None, self.window_size, self.exogenous_size])
# Create input features.
- if self._periods:
+ input_window_features = []
+ input_feature_size = 0
+ output_window_features = []
+ output_feature_size = 0
+ if self._periodicities:
_, time_features = self._compute_time_features(times)
- activation_size = self.window_size * self._buckets * len(self._periods)
- activation = array_ops.reshape(time_features, [-1, activation_size])
- else:
- activation_size = 0
- activation = None
-
+ num_time_features = self._buckets * len(self._periodicities)
+ time_features = array_ops.reshape(
+ time_features,
+ [batch_size,
+ self.window_size,
+ num_time_features])
+ input_time_features, output_time_features = array_ops.split(
+ time_features, (self.input_window_size, self.output_window_size),
+ axis=1)
+ input_feature_size += num_time_features
+ output_feature_size += num_time_features
+ input_window_features.append(input_time_features)
+ output_window_features.append(output_time_features)
if self.input_window_size:
inp = array_ops.slice(values, [0, 0, 0], [-1, self.input_window_size, -1])
- inp_size = self.input_window_size * self.num_features
- inp = array_ops.reshape(inp, [-1, inp_size])
- if activation is not None:
- activation = array_ops.concat([inp, activation], 1)
- else:
- activation = inp
- activation_size += inp_size
- assert activation_size
- activations.append((activation, activation_size))
- # Create hidden layers.
- activations += self._create_hidden_stack(activation, activation_size)
- # Create mean and convariance ops.
- predicted_mean = self._predicted_mean_op(activations)
- predicted_covariance = self._predicted_covariance_op(activations,
- self.num_features)
- return {"activations": activations,
- "mean": predicted_mean,
- "covariance": predicted_covariance}
+ input_window_features.append(
+ array_ops.reshape(
+ inp,
+ [batch_size, self.input_window_size, self.num_features]))
+ input_feature_size += self.num_features
+ if self.exogenous_size:
+ input_exogenous_features, output_exogenous_features = array_ops.split(
+ exogenous_regressors,
+ (self.input_window_size, self.output_window_size),
+ axis=1)
+ input_feature_size += self.exogenous_size
+ output_feature_size += self.exogenous_size
+ input_window_features.append(input_exogenous_features)
+ output_window_features.append(output_exogenous_features)
+ assert input_window_features
+ input_window_features = array_ops.concat(input_window_features, axis=2)
+ if output_window_features:
+ output_window_features = array_ops.concat(output_window_features, axis=2)
+ else:
+ output_window_features = array_ops.zeros(
+ [batch_size, self.output_window_size, 0],
+ dtype=self.dtype)
+ static_batch_size = times.get_shape()[0].value
+ input_window_features.set_shape(
+ [static_batch_size, self.input_window_size, input_feature_size])
+ output_window_features.set_shape(
+ [static_batch_size, self.output_window_size, output_feature_size])
+ return self._output_window_predictions(input_window_features,
+ output_window_features)
+
+ def _output_window_predictions(
+ self, input_window_features, output_window_features):
+ with self._model_scope:
+ predictions = self._model_instance(
+ input_window_features, output_window_features)
+ result_shape = [None, self.output_window_size, self.num_features]
+ for v in predictions.values():
+ v.set_shape(result_shape)
+ return predictions
def loss_op(self, targets, prediction_ops):
"""Create loss_op."""
@@ -228,6 +471,19 @@ class ARModel(model.TimeSeriesModel):
math_ops.reduce_prod(array_ops.shape(targets)), loss_op.dtype)
return loss_op
+ def _process_exogenous_features(self, times, features):
+ embedded = super(ARModel, self)._process_exogenous_features(
+ times=times, features=features)
+ if embedded is None:
+ assert self.exogenous_size == 0
+ # No embeddings. Return a zero-size [batch, times, 0] array so we don't
+ # have to special case it downstream.
+ return array_ops.zeros(
+ array_ops.concat([array_ops.shape(times), constant_op.constant([0])],
+ axis=0))
+ else:
+ return embedded
+
# TODO(allenl, agarwal): Consider better ways of warm-starting predictions.
def predict(self, features):
"""Computes predictions multiple steps into the future.
@@ -243,32 +499,49 @@ class ARModel(model.TimeSeriesModel):
segment of the time series before `TIMES`. This data is used
to start of the autoregressive computation. This should have data for
at least self.input_window_size timesteps.
+        Any exogenous features should also be included, with shapes prefixed
+        by the shape of `TIMES`.
Returns:
A dictionary with keys, "mean", "covariance". The
values are Tensors of shape [batch_size, predict window size,
num_features] and correspond to the values passed in `TIMES`.
"""
+ if not self._graph_initialized:
+ self.initialize_graph()
predict_times = math_ops.cast(
ops.convert_to_tensor(features[PredictionFeatures.TIMES]), dtypes.int32)
+ exogenous_regressors = self._process_exogenous_features(
+ times=predict_times,
+ features={key: value for key, value in features.items()
+ if key not in [TrainEvalFeatures.TIMES,
+ TrainEvalFeatures.VALUES,
+ PredictionFeatures.STATE_TUPLE]})
+ with ops.control_dependencies(
+ [check_ops.assert_equal(array_ops.shape(predict_times)[1],
+ array_ops.shape(exogenous_regressors)[1])]):
+ exogenous_regressors = array_ops.identity(exogenous_regressors)
batch_size = array_ops.shape(predict_times)[0]
num_predict_values = array_ops.shape(predict_times)[1]
prediction_iterations = ((num_predict_values + self.output_window_size - 1)
// self.output_window_size)
- # Pad predict_times so as to have exact multiple of self.output_window_size
- # values per example.
+ # Pad predict_times and exogenous regressors so as to have exact multiple of
+ # self.output_window_size values per example.
padding_size = (prediction_iterations * self.output_window_size -
num_predict_values)
- padding = array_ops.zeros([batch_size, padding_size], predict_times.dtype)
- predict_times = control_flow_ops.cond(
- padding_size > 0, lambda: array_ops.concat([predict_times, padding], 1),
- lambda: predict_times)
+ predict_times = array_ops.pad(
+ predict_times, [[0, 0], [0, padding_size]])
+ exogenous_regressors = array_ops.pad(
+ exogenous_regressors, [[0, 0], [0, padding_size], [0, 0]])
state = features[PredictionFeatures.STATE_TUPLE]
- (state_times, state_values) = state
+ (state_times, state_values, state_exogenous_regressors) = state
state_times = math_ops.cast(
ops.convert_to_tensor(state_times), dtypes.int32)
state_values = ops.convert_to_tensor(state_values, dtype=self.dtype)
+ state_exogenous_regressors = ops.convert_to_tensor(
+ state_exogenous_regressors, dtype=self.dtype)
initial_input_times = predict_times[:, :self.output_window_size]
+ initial_input_exogenous_regressors = (
+ exogenous_regressors[:, :self.output_window_size, :])
if self.input_window_size > 0:
initial_input_times = array_ops.concat(
[state_times[:, -self.input_window_size:], initial_input_times], 1)
@@ -279,6 +552,11 @@ class ARModel(model.TimeSeriesModel):
check_ops.assert_equal(values_size, times_size)
]):
initial_input_values = state_values[:, -self.input_window_size:, :]
+ initial_input_exogenous_regressors = array_ops.concat(
+ [state_exogenous_regressors[:, -self.input_window_size:, :],
+ initial_input_exogenous_regressors[
+ :, :self.output_window_size, :]],
+ axis=1)
else:
initial_input_values = 0
@@ -288,9 +566,10 @@ class ARModel(model.TimeSeriesModel):
return math_ops.less(iteration_number, prediction_iterations)
def _while_body(iteration_number, input_times, input_values,
- mean_ta, covariance_ta):
+ input_exogenous_regressors, mean_ta, covariance_ta):
"""Predict self.output_window_size values."""
- prediction_ops = self.prediction_ops(input_times, input_values)
+ prediction_ops = self.prediction_ops(
+ input_times, input_values, input_exogenous_regressors)
predicted_mean = prediction_ops["mean"]
predicted_covariance = prediction_ops["covariance"]
offset = self.output_window_size * gen_math_ops.minimum(
@@ -299,20 +578,33 @@ class ARModel(model.TimeSeriesModel):
if self.output_window_size < self.input_window_size:
new_input_values = array_ops.concat(
[input_values[:, self.output_window_size:, :], predicted_mean], 1)
+ new_input_exogenous_regressors = array_ops.concat(
+ [input_exogenous_regressors[:, -self.input_window_size:, :],
+ exogenous_regressors[
+ :, offset:offset + self.output_window_size, :]],
+ axis=1)
new_input_times = array_ops.concat([
- input_times[:, self.output_window_size:],
+ input_times[:, -self.input_window_size:],
predict_times[:, offset:offset + self.output_window_size]
], 1)
else:
new_input_values = predicted_mean[:, -self.input_window_size:, :]
+ new_input_exogenous_regressors = exogenous_regressors[
+ :,
+ offset - self.input_window_size:offset + self.output_window_size,
+ :]
new_input_times = predict_times[
:,
offset - self.input_window_size:offset + self.output_window_size]
else:
new_input_values = input_values
+ new_input_exogenous_regressors = exogenous_regressors[
+ :, offset:offset + self.output_window_size, :]
new_input_times = predict_times[:,
offset:offset + self.output_window_size]
new_input_times.set_shape(initial_input_times.get_shape())
+ new_input_exogenous_regressors.set_shape(
+ initial_input_exogenous_regressors.get_shape())
new_mean_ta = mean_ta.write(iteration_number, predicted_mean)
if isinstance(covariance_ta, tensor_array_ops.TensorArray):
new_covariance_ta = covariance_ta.write(iteration_number,
@@ -322,6 +614,7 @@ class ARModel(model.TimeSeriesModel):
return (iteration_number + 1,
new_input_times,
new_input_values,
+ new_input_exogenous_regressors,
new_mean_ta,
new_covariance_ta)
@@ -332,9 +625,13 @@ class ARModel(model.TimeSeriesModel):
if self.loss != ARModel.SQUARED_LOSS else 0.)
mean_ta_init = tensor_array_ops.TensorArray(
dtype=self.dtype, size=prediction_iterations)
- _, _, _, mean_ta, covariance_ta = control_flow_ops.while_loop(
+ _, _, _, _, mean_ta, covariance_ta = control_flow_ops.while_loop(
_while_condition, _while_body, [
- 0, initial_input_times, initial_input_values, mean_ta_init,
+ 0,
+ initial_input_times,
+ initial_input_values,
+ initial_input_exogenous_regressors,
+ mean_ta_init,
covariance_ta_init
])
@@ -366,11 +663,11 @@ class ARModel(model.TimeSeriesModel):
return {"mean": predicted_mean,
"covariance": predicted_covariance}
- def _process_window(self, features, mode):
+ def _process_window(self, features, mode, exogenous_regressors):
"""Compute model outputs on a single window of data."""
- # TODO(agarwal): Use exogenous features
times = math_ops.cast(features[TrainEvalFeatures.TIMES], dtypes.int64)
values = math_ops.cast(features[TrainEvalFeatures.VALUES], dtype=self.dtype)
+ exogenous_regressors = math_ops.cast(exogenous_regressors, dtype=self.dtype)
original_values = values
# Extra shape checking for the window size (above that in
@@ -395,7 +692,8 @@ class ARModel(model.TimeSeriesModel):
input_values = values[:, :self.input_window_size, :]
else:
input_values = None
- prediction_ops = self.prediction_ops(times, input_values)
+ prediction_ops = self.prediction_ops(
+ times, input_values, exogenous_regressors)
prediction = prediction_ops["mean"]
covariance = prediction_ops["covariance"]
targets = array_ops.slice(values, [0, self.input_window_size, 0],
@@ -419,7 +717,8 @@ class ARModel(model.TimeSeriesModel):
return model.ModelOutputs(
loss=loss,
end_state=(times[:, -self.input_window_size:],
- values[:, -self.input_window_size:, :]),
+ values[:, -self.input_window_size:, :],
+ exogenous_regressors[:, -self.input_window_size:, :]),
predictions={"mean": prediction, "covariance": covariance,
"observed": original_values[:, -self.output_window_size:]},
prediction_times=times[:, -self.output_window_size:])
@@ -454,17 +753,24 @@ class ARModel(model.TimeSeriesModel):
"""
features = {feature_name: ops.convert_to_tensor(feature_value)
for feature_name, feature_value in features.items()}
+ times = features[TrainEvalFeatures.TIMES]
+ exogenous_regressors = self._process_exogenous_features(
+ times=times,
+ features={key: value for key, value in features.items()
+ if key not in [TrainEvalFeatures.TIMES,
+ TrainEvalFeatures.VALUES,
+ PredictionFeatures.STATE_TUPLE]})
if mode == estimator_lib.ModeKeys.TRAIN:
# For training, we require the window size to be self.window_size as
# iterating sequentially on larger windows could introduce a bias.
- return self._process_window(features, mode=mode)
+ return self._process_window(
+ features, mode=mode, exogenous_regressors=exogenous_regressors)
elif mode == estimator_lib.ModeKeys.EVAL:
# For evaluation, we allow the user to pass in a larger window, in which
# case we try to cover as much of the window as possible without
# overlap. Quantitative evaluation is more efficient/correct with fixed
# windows matching self.window_size (as with training), but this looping
# allows easy plotting of "in-sample" predictions.
- times = features[TrainEvalFeatures.TIMES]
times.get_shape().assert_has_rank(2)
static_window_size = times.get_shape()[1].value
if (static_window_size is not None
@@ -500,7 +806,9 @@ class ARModel(model.TimeSeriesModel):
feature_name:
feature_value[:, base_offset:base_offset + self.window_size]
for feature_name, feature_value in features.items()},
- mode=mode)
+ mode=mode,
+ exogenous_regressors=exogenous_regressors[
+ :, base_offset:base_offset + self.window_size])
# This code needs to be updated if new predictions are added in
# self._process_window
assert len(model_outputs.predictions) == 3
@@ -525,7 +833,9 @@ class ARModel(model.TimeSeriesModel):
batch_size = array_ops.shape(times)[0]
prediction_shape = [batch_size, self.output_window_size * num_iterations,
self.num_features]
- previous_state_times, previous_state_values = state
+ (previous_state_times,
+ previous_state_values,
+ previous_state_exogenous_regressors) = state
# Make sure returned state always has windows of self.input_window_size,
# even if we were passed fewer than self.input_window_size points this
# time.
@@ -540,14 +850,24 @@ class ARModel(model.TimeSeriesModel):
self._scale_data(values)], axis=1)[:, -self.input_window_size:, :]
new_state_values.set_shape((None, self.input_window_size,
self.num_features))
+ new_exogenous_regressors = array_ops.concat(
+ [previous_state_exogenous_regressors,
+ exogenous_regressors], axis=1)[:, -self.input_window_size:, :]
+ new_exogenous_regressors.set_shape(
+ (None,
+ self.input_window_size,
+ self.exogenous_size))
else:
# There is no state to keep, and the strided slices above do not handle
# input_window_size=0.
new_state_times = previous_state_times
new_state_values = previous_state_values
+ new_exogenous_regressors = previous_state_exogenous_regressors
return model.ModelOutputs(
loss=math_ops.reduce_mean(loss_ta.stack(), axis=0),
- end_state=(new_state_times, new_state_values),
+ end_state=(new_state_times,
+ new_state_values,
+ new_exogenous_regressors),
predictions={
"mean": array_ops.reshape(
array_ops.transpose(mean_ta.stack(), [1, 0, 2, 3]),
@@ -564,12 +884,12 @@ class ARModel(model.TimeSeriesModel):
def _compute_time_features(self, time):
"""Compute some features on the time value."""
batch_size = array_ops.shape(time)[0]
- num_periods = len(self._periods)
+ num_periods = len(self._periodicities)
# Reshape to 3D.
periods = constant_op.constant(
- self._periods, shape=[1, 1, num_periods, 1], dtype=time.dtype)
+ self._periodicities, shape=[1, 1, num_periods, 1], dtype=time.dtype)
time = array_ops.reshape(time, [batch_size, -1, 1, 1])
- window_offset = time / self._periods
+ window_offset = time / self._periodicities
# Cast to appropriate type and scale to [0, 1) range
mod = (math_ops.cast(time % periods, self.dtype) * self._buckets /
math_ops.cast(periods, self.dtype))
@@ -602,9 +922,10 @@ class AnomalyMixtureARModel(ARModel):
input_window_size,
output_window_size,
num_features,
+ prediction_model_factory=FlatPredictionModel,
anomaly_distribution=GAUSSIAN_ANOMALY,
num_time_buckets=10,
- hidden_layer_sizes=None):
+ exogenous_feature_columns=None):
assert (anomaly_prior_probability < 1.0 and
anomaly_prior_probability > 0.0)
self._anomaly_prior_probability = anomaly_prior_probability
@@ -619,7 +940,8 @@ class AnomalyMixtureARModel(ARModel):
input_window_size=input_window_size,
output_window_size=output_window_size,
loss=ARModel.NORMAL_LIKELIHOOD_LOSS,
- hidden_layer_sizes=hidden_layer_sizes)
+ prediction_model_factory=prediction_model_factory,
+ exogenous_feature_columns=exogenous_feature_columns)
def _create_anomaly_ops(self, times, values, prediction_ops_dict):
anomaly_log_param = variable_scope.get_variable(
@@ -631,9 +953,9 @@ class AnomalyMixtureARModel(ARModel):
# distribution.
prediction_ops_dict["anomaly_params"] = gen_math_ops.exp(anomaly_log_param)
- def prediction_ops(self, times, values):
+ def prediction_ops(self, times, values, exogenous_regressors):
prediction_ops_dict = super(AnomalyMixtureARModel, self).prediction_ops(
- times, values)
+ times, values, exogenous_regressors)
self._create_anomaly_ops(times, values, prediction_ops_dict)
return prediction_ops_dict
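The central API change in this file is `prediction_model_factory`: `ARModel` now delegates window regression to a pluggable `tf.keras.Model`. A construction sketch following the docstring and the tests in the next file (argument values are illustrative):

```python
import functools

from tensorflow.contrib.timeseries.python.timeseries import ar_model

# Either plugin model is passed as a factory; ARModel instantiates it in
# initialize_graph() with num_features/input_window_size/output_window_size.
flat_model = ar_model.ARModel(
    periodicities=2, num_features=1, input_window_size=2,
    output_window_size=2,
    prediction_model_factory=functools.partial(
        ar_model.FlatPredictionModel, hidden_layer_sizes=[40, 10]))
lstm_model = ar_model.ARModel(
    periodicities=2, num_features=1, input_window_size=2,
    output_window_size=2,
    prediction_model_factory=functools.partial(
        ar_model.LSTMPredictionModel, num_units=16))
```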
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py
index 1e1ca4e77f..de547f835d 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py
@@ -18,12 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+
import numpy as np
+from tensorflow.contrib.timeseries.python.timeseries import ar_model
from tensorflow.contrib.timeseries.python.timeseries import input_pipeline
from tensorflow.contrib.timeseries.python.timeseries import test_utils
-from tensorflow.contrib.timeseries.python.timeseries.ar_model import AnomalyMixtureARModel
-from tensorflow.contrib.timeseries.python.timeseries.ar_model import ARModel
from tensorflow.contrib.timeseries.python.timeseries.estimators import ARRegressor
from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures
from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures
@@ -91,7 +92,7 @@ class ARModelTest(test.TestCase):
np.random.seed(3)
data_noise_stddev = 0.2
if max_loss is None:
- if loss == ARModel.NORMAL_LIKELIHOOD_LOSS:
+ if loss == ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS:
max_loss = 1.0
else:
max_loss = 0.05 / (data_noise_stddev ** 2)
@@ -137,7 +138,7 @@ class ARModelTest(test.TestCase):
test_loss = test_evaluation["loss"]
logging.info("Final test loss: %f", test_loss)
self.assertLess(test_loss, max_loss)
- if loss == ARModel.SQUARED_LOSS:
+ if loss == ar_model.ARModel.SQUARED_LOSS:
# Test that the evaluation loss is reported without input scaling.
self.assertAllClose(
test_loss,
@@ -155,18 +156,21 @@ class ARModelTest(test.TestCase):
state_times = np.expand_dims(train_data_times[:input_window_size], 0)
state_values = np.expand_dims(
train_data_values[:input_window_size, :], 0)
+ state_exogenous = state_times[:, :, None][:, :, :0]
def prediction_input_fn():
return ({
PredictionFeatures.TIMES: training.limit_epochs(
predict_times, num_epochs=1),
- PredictionFeatures.STATE_TUPLE: (state_times, state_values)
+ PredictionFeatures.STATE_TUPLE: (state_times,
+ state_values,
+ state_exogenous)
}, {})
(predictions,) = tuple(estimator.predict(input_fn=prediction_input_fn))
predicted_mean = predictions["mean"][:, 0]
true_values = predict_true_values[0, :, 0]
- if loss == ARModel.NORMAL_LIKELIHOOD_LOSS:
+ if loss == ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS:
variances = predictions["covariance"][:, 0]
standard_deviations = np.sqrt(variances)
# Note that we may get tighter bounds with more training steps.
@@ -177,26 +181,26 @@ class ARModelTest(test.TestCase):
def test_time_regression_squared(self):
self.train_helper(input_window_size=0,
train_steps=350,
- loss=ARModel.SQUARED_LOSS)
+ loss=ar_model.ARModel.SQUARED_LOSS)
def test_autoregression_squared(self):
self.train_helper(input_window_size=15,
- loss=ARModel.SQUARED_LOSS)
+ loss=ar_model.ARModel.SQUARED_LOSS)
def test_autoregression_short_input_window(self):
self.train_helper(input_window_size=8,
- loss=ARModel.SQUARED_LOSS)
+ loss=ar_model.ARModel.SQUARED_LOSS)
def test_autoregression_normal(self):
self.train_helper(input_window_size=10,
- loss=ARModel.NORMAL_LIKELIHOOD_LOSS,
+ loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS,
train_steps=300,
- max_loss=1.5,
+ max_loss=50., # Just make sure there are no exceptions.
anomaly_distribution=None)
def test_autoregression_normal_multiple_periods(self):
self.train_helper(input_window_size=10,
- loss=ARModel.NORMAL_LIKELIHOOD_LOSS,
+ loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS,
max_loss=2.0,
multiple_periods=True,
anomaly_distribution=None)
@@ -204,15 +208,15 @@ class ARModelTest(test.TestCase):
def test_autoregression_normal_anomalies_normal(self):
self.train_helper(
input_window_size=10,
- loss=ARModel.NORMAL_LIKELIHOOD_LOSS,
- anomaly_distribution=AnomalyMixtureARModel.GAUSSIAN_ANOMALY)
+ loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS,
+ anomaly_distribution=ar_model.AnomalyMixtureARModel.GAUSSIAN_ANOMALY)
def test_autoregression_normal_anomalies_cauchy(self):
self.train_helper(
input_window_size=10,
max_loss=1.5,
- loss=ARModel.NORMAL_LIKELIHOOD_LOSS,
- anomaly_distribution=AnomalyMixtureARModel.CAUCHY_ANOMALY)
+ loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS,
+ anomaly_distribution=ar_model.AnomalyMixtureARModel.CAUCHY_ANOMALY)
def test_wrong_window_size(self):
estimator = ARRegressor(
@@ -234,19 +238,43 @@ class ARModelTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "requires a window of at least"):
estimator.evaluate(input_fn=_bad_window_size_input_fn, steps=1)
- def test_predictions_direct(self):
+ def test_predictions_direct_flat(self):
+ g = ops.Graph()
+ with g.as_default():
+ model = ar_model.ARModel(periodicities=2,
+ num_features=1,
+ num_time_buckets=10,
+ input_window_size=2,
+ output_window_size=2,
+ prediction_model_factory=functools.partial(
+ ar_model.FlatPredictionModel,
+ hidden_layer_sizes=[40, 10]))
+ with session.Session():
+ predicted_values = model.predict({
+ PredictionFeatures.TIMES: [[4, 6, 10]],
+ PredictionFeatures.STATE_TUPLE: (
+ [[1, 2]], [[[1.], [2.]]], [[[], []]])
+ })
+ variables.global_variables_initializer().run()
+ self.assertAllEqual(predicted_values["mean"].eval().shape,
+ [1, 3, 1])
+
+ def test_predictions_direct_lstm(self):
g = ops.Graph()
with g.as_default():
- model = ARModel(periodicities=2,
- num_features=1,
- num_time_buckets=10,
- input_window_size=2,
- output_window_size=2,
- hidden_layer_sizes=[40, 10])
+ model = ar_model.ARModel(periodicities=2,
+ num_features=1,
+ num_time_buckets=10,
+ input_window_size=2,
+ output_window_size=2,
+ prediction_model_factory=functools.partial(
+ ar_model.LSTMPredictionModel,
+ num_units=16))
with session.Session():
predicted_values = model.predict({
PredictionFeatures.TIMES: [[4, 6, 10]],
- PredictionFeatures.STATE_TUPLE: ([[1, 2]], [[[1.], [2.]]])
+ PredictionFeatures.STATE_TUPLE: (
+ [[1, 2]], [[[1.], [2.]]], [[[], []]])
})
variables.global_variables_initializer().run()
self.assertAllEqual(predicted_values["mean"].eval().shape,
@@ -255,11 +283,11 @@ class ARModelTest(test.TestCase):
def test_long_eval(self):
g = ops.Graph()
with g.as_default():
- model = ARModel(periodicities=2,
- num_features=1,
- num_time_buckets=10,
- input_window_size=2,
- output_window_size=1)
+ model = ar_model.ARModel(periodicities=2,
+ num_features=1,
+ num_time_buckets=10,
+ input_window_size=2,
+ output_window_size=1)
raw_features = {
TrainEvalFeatures.TIMES: [[1, 3, 5, 7, 11]],
TrainEvalFeatures.VALUES: [[[1.], [2.], [3.], [4.], [5.]]]}
@@ -305,11 +333,11 @@ class ARModelTest(test.TestCase):
def test_long_eval_discard_indivisible(self):
g = ops.Graph()
with g.as_default():
- model = ARModel(periodicities=2,
- num_features=1,
- num_time_buckets=10,
- input_window_size=2,
- output_window_size=2)
+ model = ar_model.ARModel(periodicities=2,
+ num_features=1,
+ num_time_buckets=10,
+ input_window_size=2,
+ output_window_size=2)
raw_features = {
TrainEvalFeatures.TIMES: [[1, 3, 5, 7, 11]],
TrainEvalFeatures.VALUES: [[[1.], [2.], [3.], [4.], [5.]]]}
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index 886e1846e2..af68aa03cf 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+
from tensorflow.contrib.timeseries.python.timeseries import ar_model
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib
@@ -28,6 +30,7 @@ from tensorflow.contrib.timeseries.python.timeseries.state_space_models import s
from tensorflow.contrib.timeseries.python.timeseries.state_space_models.filtering_postprocessor import StateInterpolatingAnomalyDetector
from tensorflow.python.estimator import estimator_lib
+from tensorflow.python.estimator.canned import optimizers
from tensorflow.python.estimator.export import export_lib
from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import dtypes
@@ -35,6 +38,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.training import training as train
from tensorflow.python.util import nest
@@ -61,7 +65,10 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
input_statistics_generator = math_utils.InputStatisticsFromMiniBatch(
dtype=model.dtype, num_features=model.num_features)
if state_manager is None:
- state_manager = state_management.PassthroughStateManager()
+ if isinstance(model, ar_model.ARModel):
+ state_manager = state_management.FilteringOnlyStateManager()
+ else:
+ state_manager = state_management.PassthroughStateManager()
if optimizer is None:
optimizer = train.AdamOptimizer(0.02)
self._model = model
@@ -74,12 +81,137 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
model_dir=model_dir,
config=config)
- # TODO(allenl): A parsing input receiver function, which takes a serialized
- # tf.Example containing all features (times, values, any exogenous features)
- # and serialized model state (possibly also as a tf.Example).
- def build_raw_serving_input_receiver_fn(self,
- default_batch_size=None,
- default_series_length=None):
+ def _model_start_state_placeholders(
+ self, batch_size_tensor, static_batch_size=None):
+ """Creates placeholders with zeroed start state for the current model."""
+ gathered_state = {}
+ # Models may not know the shape of their state without creating some
+ # variables/ops. Avoid polluting the default graph by making a new one. We
+ # use only static metadata from the returned Tensors.
+ with ops.Graph().as_default():
+ self._model.initialize_graph()
+ # Evaluate the initial state as same-dtype "zero" values. These zero
+ # constants aren't used, but are necessary for feeding to
+ # placeholder_with_default for the "cold start" case where state is not
+ # fed to the model.
+ def _zeros_like_constant(tensor):
+ return tensor_util.constant_value(array_ops.zeros_like(tensor))
+ start_state = nest.map_structure(
+ _zeros_like_constant, self._model.get_start_state())
+ for prefixed_state_name, state in ts_head_lib.state_to_dictionary(
+ start_state).items():
+ state_shape_with_batch = tensor_shape.TensorShape(
+ (static_batch_size,)).concatenate(state.shape)
+ default_state_broadcast = array_ops.tile(
+ state[None, ...],
+ multiples=array_ops.concat(
+ [batch_size_tensor[None],
+ array_ops.ones(len(state.shape), dtype=dtypes.int32)],
+ axis=0))
+ gathered_state[prefixed_state_name] = array_ops.placeholder_with_default(
+ input=default_state_broadcast,
+ name=prefixed_state_name,
+ shape=state_shape_with_batch)
+ return gathered_state
+
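`_model_start_state_placeholders` leans on `placeholder_with_default` so that exported models accept explicit state when it is fed, and fall back to broadcast zeros for the cold-start case. A minimal standalone sketch of that pattern, with hypothetical shapes and names (not part of this module):

```python
import tensorflow as tf

# A state placeholder that defaults to zeros tiled to the batch size,
# mirroring the cold-start behavior built above.
batch_size = tf.placeholder_with_default(tf.constant(2), shape=[])
default_state = tf.tile(tf.zeros([3])[None, ...],
                        multiples=tf.stack([batch_size, 1]))
state = tf.placeholder_with_default(default_state, shape=[None, 3],
                                    name="model_state")

with tf.Session() as session:
  print(session.run(state))  # Cold start: broadcast zeros, shape (2, 3).
  print(session.run(state, feed_dict={state: [[1., 2., 3.]]}))  # Fed state.
```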
+ def build_one_shot_parsing_serving_input_receiver_fn(
+ self, filtering_length, prediction_length, default_batch_size=None,
+ values_input_dtype=None, truncate_values=False):
+ """Build an input_receiver_fn for export_savedmodel accepting tf.Examples.
+
+ Only compatible with `OneShotPredictionHead` (see `head`).
+
+ Args:
+ filtering_length: The number of time steps used as input to the model, for
+ which values are provided. If more than `filtering_length` values are
+ provided (via `truncate_values`), only the first `filtering_length`
+ values are used.
+ prediction_length: The number of time steps requested as predictions from
+ the model. Times and all exogenous features must be provided for these
+ steps.
+ default_batch_size: If specified, must be a scalar integer. Sets the batch
+ size in the static shape information of all feature Tensors, which means
+ only this batch size will be accepted by the exported model. If None
+ (default), static shape information for batch sizes is omitted.
+ values_input_dtype: An optional dtype specification for values in the
+ tf.Example protos (either float32 or int64, since these are the numeric
+ types supported by tf.Example). After parsing, values are cast to the
+ model's dtype (float32 or float64).
+ truncate_values: If True, expects `filtering_length + prediction_length`
+ values to be provided, but only uses the first `filtering_length`. If
+ False (default), exactly `filtering_length` values must be provided.
+
+ Returns:
+ An input_receiver_fn which may be passed to the Estimator's
+ export_savedmodel.
+
+ Expects features contained in a vector of serialized tf.Examples with
+ shape [batch size] (dtype `tf.string`), each tf.Example containing
+ features with the following shapes:
+ times: [filtering_length + prediction_length] integer
+ values: [filtering_length, num features] floating point. If
+ `truncate_values` is True, expects `filtering_length +
+ prediction_length` values but only uses the first `filtering_length`.
+ all exogenous features: [filtering_length + prediction_length, ...]
+ (various dtypes)
+ """
+ if values_input_dtype is None:
+ values_input_dtype = dtypes.float32
+ if truncate_values:
+ values_proto_length = filtering_length + prediction_length
+ else:
+ values_proto_length = filtering_length
+
+ def _serving_input_receiver_fn():
+ """A receiver function to be passed to export_savedmodel."""
+ times_column = feature_column.numeric_column(
+ key=feature_keys.TrainEvalFeatures.TIMES, dtype=dtypes.int64)
+ values_column = feature_column.numeric_column(
+ key=feature_keys.TrainEvalFeatures.VALUES, dtype=values_input_dtype,
+ shape=(self._model.num_features,))
+ parsed_features_no_sequence = (
+ feature_column.make_parse_example_spec(
+ list(self._model.exogenous_feature_columns)
+ + [times_column, values_column]))
+ parsed_features = {}
+ for key, feature_spec in parsed_features_no_sequence.items():
+ if isinstance(feature_spec, parsing_ops.FixedLenFeature):
+ if key == feature_keys.TrainEvalFeatures.VALUES:
+ parsed_features[key] = feature_spec._replace(
+ shape=((values_proto_length,)
+ + feature_spec.shape))
+ else:
+ parsed_features[key] = feature_spec._replace(
+ shape=((filtering_length + prediction_length,)
+ + feature_spec.shape))
+ elif feature_spec.dtype == dtypes.string:
+ parsed_features[key] = parsing_ops.FixedLenFeature(
+ shape=(filtering_length + prediction_length,),
+ dtype=dtypes.string)
+ else: # VarLenFeature
+ raise ValueError("VarLenFeatures not supported, got %s for key %s"
+ % (feature_spec, key))
+ tfexamples = array_ops.placeholder(
+ shape=[default_batch_size], dtype=dtypes.string, name="input")
+ features = parsing_ops.parse_example(
+ serialized=tfexamples,
+ features=parsed_features)
+ features[feature_keys.TrainEvalFeatures.TIMES] = array_ops.squeeze(
+ features[feature_keys.TrainEvalFeatures.TIMES], axis=-1)
+ features[feature_keys.TrainEvalFeatures.VALUES] = math_ops.cast(
+ features[feature_keys.TrainEvalFeatures.VALUES],
+ dtype=self._model.dtype)[:, :filtering_length]
+ features.update(
+ self._model_start_state_placeholders(
+ batch_size_tensor=array_ops.shape(
+ features[feature_keys.TrainEvalFeatures.TIMES])[0],
+ static_batch_size=default_batch_size))
+ return export_lib.ServingInputReceiver(
+ features, {"examples": tfexamples})
+ return _serving_input_receiver_fn
+
+ def build_raw_serving_input_receiver_fn(
+ self, default_batch_size=None, default_series_length=None):
"""Build an input_receiver_fn for export_savedmodel which accepts arrays.
Automatically creates placeholders for exogenous `FeatureColumn`s passed to
@@ -144,34 +276,10 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
+ batch_only_feature_shape[1:])
placeholders[feature_key] = array_ops.placeholder(
dtype=value_dtype, name=feature_key, shape=feature_shape)
- # Models may not know the shape of their state without creating some
- # variables/ops. Avoid polluting the default graph by making a new one. We
- # use only static metadata from the returned Tensors.
- with ops.Graph().as_default():
- self._model.initialize_graph()
- # Evaluate the initial state as same-dtype "zero" values. These zero
- # constants aren't used, but are necessary for feeding to
- # placeholder_with_default for the "cold start" case where state is not
- # fed to the model.
- def _zeros_like_constant(tensor):
- return tensor_util.constant_value(array_ops.zeros_like(tensor))
- start_state = nest.map_structure(
- _zeros_like_constant, self._model.get_start_state())
batch_size_tensor = array_ops.shape(time_placeholder)[0]
- for prefixed_state_name, state in ts_head_lib.state_to_dictionary(
- start_state).items():
- state_shape_with_batch = tensor_shape.TensorShape(
- (default_batch_size,)).concatenate(state.shape)
- default_state_broadcast = array_ops.tile(
- state[None, ...],
- multiples=array_ops.concat(
- [batch_size_tensor[None],
- array_ops.ones(len(state.shape), dtype=dtypes.int32)],
- axis=0))
- placeholders[prefixed_state_name] = array_ops.placeholder_with_default(
- input=default_state_broadcast,
- name=prefixed_state_name,
- shape=state_shape_with_batch)
+ placeholders.update(
+ self._model_start_state_placeholders(
+ batch_size_tensor, static_batch_size=default_batch_size))
return export_lib.ServingInputReceiver(placeholders, placeholders)
return _serving_input_receiver_fn
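Paired with `OneShotPredictionHead`, the new parsing receiver allows serving from serialized `tf.Example` protos rather than raw arrays. A short sketch of the export call, assuming an `estimator` constructed with that head and an existing `export_base_dir` (the head tests later in this diff follow the same shape):

```python
# Export a SavedModel accepting serialized tf.Examples: 20 filtering steps
# with values, then 15 prediction steps with times/exogenous features only.
input_receiver_fn = estimator.build_one_shot_parsing_serving_input_receiver_fn(
    filtering_length=20, prediction_length=15)
export_location = estimator.export_savedmodel(export_base_dir,
                                              input_receiver_fn)
```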
@@ -190,7 +298,7 @@ class ARRegressor(TimeSeriesRegressor):
def __init__(
self, periodicities, input_window_size, output_window_size,
- num_features, num_time_buckets=10,
+ num_features, exogenous_feature_columns=None, num_time_buckets=10,
loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS, hidden_layer_sizes=None,
anomaly_prior_probability=None, anomaly_distribution=None,
optimizer=None, model_dir=None, config=None):
@@ -205,7 +313,12 @@ class ARRegressor(TimeSeriesRegressor):
output_window_size: Number of future time steps to predict. Note that
setting it to > 1 empirically seems to give a better fit.
num_features: The dimensionality of the time series (one for univariate,
- more than one for multivariate).
+ more than one for multivariate).
+ exogenous_feature_columns: A list of `tf.feature_column`s (for example
+ `tf.feature_column.embedding_column`) corresponding to exogenous
+ features which provide extra information to the model but are not part
+ of the series to be predicted. Passed to
+ `tf.feature_column.input_layer`.
num_time_buckets: Number of buckets into which to divide (time %
periodicity) for generating time based features.
loss: Loss function to use for training. Currently supported values are
@@ -241,10 +354,13 @@ class ARRegressor(TimeSeriesRegressor):
anomaly_distribution = ar_model.AnomalyMixtureARModel.GAUSSIAN_ANOMALY
model = ar_model.ARModel(
periodicities=periodicities, num_features=num_features,
+ prediction_model_factory=functools.partial(
+ ar_model.FlatPredictionModel,
+ hidden_layer_sizes=hidden_layer_sizes),
+ exogenous_feature_columns=exogenous_feature_columns,
num_time_buckets=num_time_buckets,
input_window_size=input_window_size,
- output_window_size=output_window_size, loss=loss,
- hidden_layer_sizes=hidden_layer_sizes)
+ output_window_size=output_window_size, loss=loss)
else:
if loss != ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS:
raise ValueError(
@@ -255,8 +371,11 @@ class ARRegressor(TimeSeriesRegressor):
input_window_size=input_window_size,
output_window_size=output_window_size,
num_features=num_features,
+ prediction_model_factory=functools.partial(
+ ar_model.FlatPredictionModel,
+ hidden_layer_sizes=hidden_layer_sizes),
+ exogenous_feature_columns=exogenous_feature_columns,
num_time_buckets=num_time_buckets,
- hidden_layer_sizes=hidden_layer_sizes,
anomaly_prior_probability=anomaly_prior_probability,
anomaly_distribution=anomaly_distribution)
state_manager = state_management.FilteringOnlyStateManager()
@@ -268,11 +387,167 @@ class ARRegressor(TimeSeriesRegressor):
config=config)
+# TODO(b/113684821): Add detailed documentation on what the input_fn should do.
+# Add an example of making and returning a Dataset object. Determine if
+# endogenous features can be passed in as FeatureColumns. Move ARModel's loss
+# functions into a more general location.
+class LSTMAutoRegressor(TimeSeriesRegressor):
+ """An Estimator for an LSTM autoregressive model.
+
+  LSTMAutoRegressor is a window-based model: it consumes fixed input windows of
+  length `input_window_size` and produces fixed output windows of length
+  `output_window_size`. These two parameters must add up to the window_size of
+  data returned by the `input_fn`.
+
+  Each periodicity in the `periodicities` arg is divided into `num_timesteps`
+  buckets, each of which is represented as a time feature added to the model.
+
+  A good heuristic for picking an appropriate periodicity for a given data set
+  is the length of the cycles in the data. For example, energy usage in a
+  home is typically cyclic each day. If the time feature in a home energy
+  usage dataset is in units of hours, then 24 would be an appropriate
+  periodicity. Similarly, a good heuristic for `num_timesteps` is how often the
+  data is expected to change within the cycle. For the aforementioned home
+  energy usage dataset with a periodicity of 24, 48 would be a reasonable
+  value if usage is expected to change every half hour.
+
+ Each feature's value for a given example with time t is the difference
+ between t and the start of the timestep it falls under. If it doesn't fall
+ under a feature's associated timestep, then that feature's value is zero.
+
+ For example: if `periodicities` = (9, 12) and `num_timesteps` = 3, then 6
+ features would be added to the model, 3 for periodicity 9 and 3 for
+ periodicity 12.
+
+ For an example data point where t = 17:
+ - It's in the 3rd timestep for periodicity 9 (2nd period is 9-18 and 3rd
+ timestep is 15-18)
+ - It's in the 2nd timestep for periodicity 12 (2nd period is 12-24 and
+    2nd timestep is 16-20).
+
+ Therefore the 6 added features for this row with t = 17 would be:
+
+ # Feature name (periodicity#_timestep#), feature value
+ P9_T1, 0 # not in first timestep
+ P9_T2, 0 # not in second timestep
+ P9_T3, 2 # 17 - 15 since 15 is the start of the 3rd timestep
+ P12_T1, 0 # not in first timestep
+ P12_T2, 1 # 17 - 16 since 16 is the start of the 2nd timestep
+ P12_T3, 0 # not in third timestep
+
+ Example Code:
+
+ ```python
+ extra_feature_columns = (
+ feature_column.numeric_column("exogenous_variable"),
+ )
+
+ estimator = LSTMAutoRegressor(
+ periodicities=10,
+ input_window_size=10,
+ output_window_size=5,
+ model_dir="/path/to/model/dir",
+ num_features=1,
+ extra_feature_columns=extra_feature_columns,
+ num_timesteps=50,
+ num_units=10,
+ optimizer=tf.train.ProximalAdagradOptimizer(...))
+
+ # Input builders
+ def input_fn_train():
+ return {
+ "times": tf.range(15)[None, :],
+ "values": tf.random_normal(shape=[1, 15, 1])
+ }
+ estimator.train(input_fn=input_fn_train, steps=100)
+
+ def input_fn_eval():
+ pass
+ metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
+
+ def input_fn_predict():
+ pass
+ predictions = estimator.predict(input_fn=input_fn_predict)
+ ```
+ """
+
+ def __init__(self,
+ periodicities,
+ input_window_size,
+ output_window_size,
+ model_dir=None,
+ num_features=1,
+ extra_feature_columns=None,
+ num_timesteps=10,
+ loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS,
+ num_units=128,
+ optimizer="Adam",
+ config=None):
+ """Initialize the Estimator.
+
+ Args:
+ periodicities: periodicities of the input data, in the same units as the
+ time feature (for example 24 if feeding hourly data with a daily
+ periodicity, or 60 * 24 if feeding minute-level data with daily
+ periodicity). Note this can be a single value or a list of values for
+ multiple periodicities.
+ input_window_size: Number of past time steps of data to look at when doing
+ the regression.
+ output_window_size: Number of future time steps to predict. Note that
+ setting this value to > 1 empirically seems to give a better fit.
+      model_dir: Directory in which to save model parameters, graph, etc. This
+        can also be used to load checkpoints from the directory into an
+        estimator to continue training a previously saved model.
+ num_features: The dimensionality of the time series (default value is
+ one for univariate, more than one for multivariate).
+ extra_feature_columns: A list of `tf.feature_column`s (for example
+ `tf.feature_column.embedding_column`) corresponding to features which
+ provide extra information to the model but are not part of the series to
+ be predicted.
+ num_timesteps: Number of buckets into which to divide (time %
+ periodicity). This value multiplied by the number of periodicities is
+ the number of time features added to the model.
+ loss: Loss function to use for training. Currently supported values are
+ SQUARED_LOSS and NORMAL_LIKELIHOOD_LOSS. Note that for
+ NORMAL_LIKELIHOOD_LOSS, we train the covariance term as well. For
+ SQUARED_LOSS, the evaluation loss is reported based on un-scaled
+ observations and predictions, while the training loss is computed on
+ normalized data.
+ num_units: The size of the hidden state in the encoder and decoder LSTM
+ cells.
+ optimizer: string, `tf.train.Optimizer` object, or callable that defines
+ the optimizer algorithm to use for training. Defaults to the Adam
+ optimizer with a learning rate of 0.01.
+ config: Optional `estimator.RunConfig` object to configure the runtime
+ settings.
+ """
+ optimizer = optimizers.get_optimizer_instance(
+ optimizer, learning_rate=0.01)
+ model = ar_model.ARModel(
+ periodicities=periodicities,
+ input_window_size=input_window_size,
+ output_window_size=output_window_size,
+ num_features=num_features,
+ exogenous_feature_columns=extra_feature_columns,
+ num_time_buckets=num_timesteps,
+ loss=loss,
+ prediction_model_factory=functools.partial(
+ ar_model.LSTMPredictionModel, num_units=num_units))
+ state_manager = state_management.FilteringOnlyStateManager()
+ super(LSTMAutoRegressor, self).__init__(
+ model=model,
+ state_manager=state_manager,
+ optimizer=optimizer,
+ model_dir=model_dir,
+ config=config,
+ head_type=ts_head_lib.OneShotPredictionHead)
+
+
class StateSpaceRegressor(TimeSeriesRegressor):
"""An Estimator for general state space models."""
def __init__(self, model, state_manager=None, optimizer=None, model_dir=None,
- config=None):
+ config=None, head_type=ts_head_lib.TimeSeriesRegressionHead):
"""See TimeSeriesRegressor. Uses the ChainingStateManager by default."""
if not isinstance(model, state_space_model.StateSpaceModel):
raise ValueError(
@@ -285,7 +560,8 @@ class StateSpaceRegressor(TimeSeriesRegressor):
state_manager=state_manager,
optimizer=optimizer,
model_dir=model_dir,
- config=config)
+ config=config,
+ head_type=head_type)
class StructuralEnsembleRegressor(StateSpaceRegressor):
@@ -328,7 +604,8 @@ class StructuralEnsembleRegressor(StateSpaceRegressor):
anomaly_prior_probability=None,
optimizer=None,
model_dir=None,
- config=None):
+ config=None,
+ head_type=ts_head_lib.TimeSeriesRegressionHead):
"""Initialize the Estimator.
Args:
@@ -385,6 +662,8 @@ class StructuralEnsembleRegressor(StateSpaceRegressor):
from tf.train.Optimizer. Defaults to Adam with step size 0.02.
model_dir: See `Estimator`.
config: See `Estimator`.
+ head_type: The kind of head to use for the model (inheriting from
+ `TimeSeriesRegressionHead`).
"""
if anomaly_prior_probability is not None:
filtering_postprocessor = StateInterpolatingAnomalyDetector(
@@ -408,4 +687,5 @@ class StructuralEnsembleRegressor(StateSpaceRegressor):
model=model,
optimizer=optimizer,
model_dir=model_dir,
- config=config)
+ config=config,
+ head_type=head_type)
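The bucketing arithmetic in the `LSTMAutoRegressor` docstring can be verified with a few lines of plain Python. The helper below is purely illustrative (it is not part of the module); it reproduces the t = 17 example for `periodicities` = (9, 12) and `num_timesteps` = 3:

```python
def time_features(t, periodicities, num_timesteps):
  """Computes the P<periodicity>_T<timestep> features described above."""
  features = {}
  for periodicity in periodicities:
    bucket_size = float(periodicity) / num_timesteps
    offset = t % periodicity  # Position of t within its current period.
    for timestep in range(num_timesteps):
      start = timestep * bucket_size
      in_bucket = start <= offset < start + bucket_size
      features["P%d_T%d" % (periodicity, timestep + 1)] = (
          offset - start if in_bucket else 0.0)
  return features

print(time_features(17, periodicities=(9, 12), num_timesteps=3))
# {'P9_T1': 0.0, 'P9_T2': 0.0, 'P9_T3': 2.0,
#  'P12_T1': 0.0, 'P12_T2': 1.0, 'P12_T3': 0.0}
```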
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index 9f161c1695..6ec7184c68 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
import tempfile
import numpy
@@ -29,6 +30,7 @@ from tensorflow.contrib.timeseries.python.timeseries import saved_model_utils
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator_lib
+from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
@@ -48,12 +50,17 @@ class TimeSeriesRegressorTest(test.TestCase):
def _fit_restore_fit_test_template(self, estimator_fn, dtype):
"""Tests restoring previously fit models."""
model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
- first_estimator = estimator_fn(model_dir)
+ exogenous_feature_columns = (
+ feature_column.numeric_column("exogenous"),
+ )
+ first_estimator = estimator_fn(model_dir, exogenous_feature_columns)
times = numpy.arange(20, dtype=numpy.int64)
values = numpy.arange(20, dtype=dtype.as_numpy_dtype)
+ exogenous = numpy.arange(20, dtype=dtype.as_numpy_dtype)
features = {
feature_keys.TrainEvalFeatures.TIMES: times,
- feature_keys.TrainEvalFeatures.VALUES: values
+ feature_keys.TrainEvalFeatures.VALUES: values,
+ "exogenous": exogenous
}
train_input_fn = input_pipeline.RandomWindowInputFn(
input_pipeline.NumpyReader(features), shuffle_seed=2, num_threads=1,
@@ -61,21 +68,29 @@ class TimeSeriesRegressorTest(test.TestCase):
eval_input_fn = input_pipeline.RandomWindowInputFn(
input_pipeline.NumpyReader(features), shuffle_seed=3, num_threads=1,
batch_size=16, window_size=16)
- first_estimator.train(input_fn=train_input_fn, steps=5)
- first_loss_before_fit = first_estimator.evaluate(
- input_fn=eval_input_fn, steps=1)["loss"]
- first_estimator.train(input_fn=train_input_fn, steps=50)
+ first_estimator.train(input_fn=train_input_fn, steps=1)
+ first_evaluation = first_estimator.evaluate(
+ input_fn=eval_input_fn, steps=1)
+ first_loss_before_fit = first_evaluation["loss"]
+ self.assertAllEqual(first_loss_before_fit, first_evaluation["average_loss"])
+ self.assertAllEqual([], first_loss_before_fit.shape)
+ first_estimator.train(input_fn=train_input_fn, steps=1)
first_loss_after_fit = first_estimator.evaluate(
input_fn=eval_input_fn, steps=1)["loss"]
- self.assertLess(first_loss_after_fit, first_loss_before_fit)
- second_estimator = estimator_fn(model_dir)
- second_estimator.train(input_fn=train_input_fn, steps=2)
+ self.assertAllEqual([], first_loss_after_fit.shape)
+ second_estimator = estimator_fn(model_dir, exogenous_feature_columns)
+ second_estimator.train(input_fn=train_input_fn, steps=1)
whole_dataset_input_fn = input_pipeline.WholeDatasetInputFn(
input_pipeline.NumpyReader(features))
whole_dataset_evaluation = second_estimator.evaluate(
input_fn=whole_dataset_input_fn, steps=1)
+ exogenous_values_ten_steps = {
+ "exogenous": numpy.arange(
+ 10, dtype=dtype.as_numpy_dtype)[None, :, None]
+ }
predict_input_fn = input_pipeline.predict_continuation_input_fn(
evaluation=whole_dataset_evaluation,
+ exogenous_features=exogenous_values_ten_steps,
steps=10)
# Also tests that limit_epochs in predict_continuation_input_fn prevents
# infinite iteration
@@ -92,6 +107,7 @@ class TimeSeriesRegressorTest(test.TestCase):
saved_prediction = saved_model_utils.predict_continuation(
continue_from=whole_dataset_evaluation,
steps=10,
+ exogenous_features=exogenous_values_ten_steps,
signatures=signatures,
session=sess)
# Saved model predictions should be the same as Estimator predictions
@@ -104,7 +120,8 @@ class TimeSeriesRegressorTest(test.TestCase):
continue_from=whole_dataset_evaluation,
features={
feature_keys.FilteringFeatures.TIMES: times[None, -1] + 2,
- feature_keys.FilteringFeatures.VALUES: values[None, -1] + 2.
+ feature_keys.FilteringFeatures.VALUES: values[None, -1] + 2.,
+ "exogenous": values[None, -1, None] + 12.
},
signatures=signatures,
session=sess)
@@ -112,6 +129,10 @@ class TimeSeriesRegressorTest(test.TestCase):
second_saved_prediction = saved_model_utils.predict_continuation(
continue_from=first_filtering,
steps=1,
+ exogenous_features={
+ "exogenous": numpy.arange(
+ 1, dtype=dtype.as_numpy_dtype)[None, :, None]
+ },
signatures=signatures,
session=sess)
self.assertEqual(
@@ -122,7 +143,8 @@ class TimeSeriesRegressorTest(test.TestCase):
continue_from=first_filtering,
features={
feature_keys.FilteringFeatures.TIMES: times[-1] + 3,
- feature_keys.FilteringFeatures.VALUES: values[-1] + 3.
+ feature_keys.FilteringFeatures.VALUES: values[-1] + 3.,
+ "exogenous": values[-1, None] + 13.
},
signatures=signatures,
session=sess)
@@ -131,7 +153,8 @@ class TimeSeriesRegressorTest(test.TestCase):
six.assertCountEqual(
self,
[feature_keys.FilteringFeatures.TIMES,
- feature_keys.FilteringFeatures.VALUES],
+ feature_keys.FilteringFeatures.VALUES,
+ "exogenous"],
signatures.signature_def[
feature_keys.SavedModelLabels.COLD_START_FILTER].inputs.keys())
batch_numpy_times = numpy.tile(
@@ -142,7 +165,8 @@ class TimeSeriesRegressorTest(test.TestCase):
session=sess,
features={
feature_keys.FilteringFeatures.TIMES: batch_numpy_times,
- feature_keys.FilteringFeatures.VALUES: batch_numpy_values
+ feature_keys.FilteringFeatures.VALUES: batch_numpy_values,
+ "exogenous": 10. + batch_numpy_values
}
)
predict_times = numpy.tile(
@@ -150,28 +174,92 @@ class TimeSeriesRegressorTest(test.TestCase):
predictions = saved_model_utils.predict_continuation(
continue_from=state,
times=predict_times,
+ exogenous_features={
+ "exogenous": numpy.tile(numpy.arange(
+ 15, dtype=dtype.as_numpy_dtype), (10,))[None, :, None]
+ },
signatures=signatures,
session=sess)
self.assertAllEqual([10, 15, 1], predictions["mean"].shape)
- def test_fit_restore_fit_ar_regressor(self):
- def _estimator_fn(model_dir):
+ def test_fit_restore_fit_ar_flat(self):
+ def _estimator_fn(model_dir, exogenous_feature_columns):
return estimators.ARRegressor(
periodicities=10, input_window_size=10, output_window_size=6,
num_features=1, model_dir=model_dir, config=_SeedRunConfig(),
# This test is flaky with normal likelihood loss (could add more
# training iterations instead).
- loss=ar_model.ARModel.SQUARED_LOSS)
+ loss=ar_model.ARModel.SQUARED_LOSS,
+ exogenous_feature_columns=exogenous_feature_columns)
+ self._fit_restore_fit_test_template(_estimator_fn, dtype=dtypes.float32)
+
+ def test_fit_restore_fit_ar_lstm(self):
+ def _estimator_fn(model_dir, exogenous_feature_columns):
+ return estimators.TimeSeriesRegressor(
+ model=ar_model.ARModel(
+ periodicities=10, input_window_size=10, output_window_size=6,
+ num_features=1,
+ exogenous_feature_columns=exogenous_feature_columns,
+ prediction_model_factory=functools.partial(
+ ar_model.LSTMPredictionModel,
+ num_units=10)),
+ config=_SeedRunConfig(),
+ model_dir=model_dir)
self._fit_restore_fit_test_template(_estimator_fn, dtype=dtypes.float32)
def test_fit_restore_fit_structural_ensemble_regressor(self):
dtype = dtypes.float32
- def _estimator_fn(model_dir):
+ def _estimator_fn(model_dir, exogenous_feature_columns):
return estimators.StructuralEnsembleRegressor(
num_features=1, periodicities=10, model_dir=model_dir, dtype=dtype,
- config=_SeedRunConfig())
+ config=_SeedRunConfig(),
+ exogenous_feature_columns=exogenous_feature_columns)
self._fit_restore_fit_test_template(_estimator_fn, dtype=dtype)
+ def test_structural_ensemble_numpy_input(self):
+ numpy_data = {"times": numpy.arange(50),
+ "values": numpy.random.normal(size=[50])}
+ estimators.StructuralEnsembleRegressor(
+ num_features=1, periodicities=[], model_dir=self.get_temp_dir(),
+ config=_SeedRunConfig()).train(
+ input_pipeline.WholeDatasetInputFn(
+ input_pipeline.NumpyReader(numpy_data)),
+ steps=1)
+
+ def test_ar_lstm_regressor(self):
+ dtype = dtypes.float32
+ model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+ exogenous_feature_columns = (
+ feature_column.numeric_column("exogenous"),
+ )
+ estimator = estimators.LSTMAutoRegressor(
+ periodicities=10,
+ input_window_size=10,
+ output_window_size=6,
+ model_dir=model_dir,
+ num_features=1,
+ extra_feature_columns=exogenous_feature_columns,
+ num_units=10,
+ config=_SeedRunConfig())
+ times = numpy.arange(20, dtype=numpy.int64)
+ values = numpy.arange(20, dtype=dtype.as_numpy_dtype)
+ exogenous = numpy.arange(20, dtype=dtype.as_numpy_dtype)
+ features = {
+ feature_keys.TrainEvalFeatures.TIMES: times,
+ feature_keys.TrainEvalFeatures.VALUES: values,
+ "exogenous": exogenous
+ }
+ train_input_fn = input_pipeline.RandomWindowInputFn(
+ input_pipeline.NumpyReader(features), shuffle_seed=2, num_threads=1,
+ batch_size=16, window_size=16)
+ eval_input_fn = input_pipeline.RandomWindowInputFn(
+ input_pipeline.NumpyReader(features), shuffle_seed=3, num_threads=1,
+ batch_size=16, window_size=16)
+ estimator.train(input_fn=train_input_fn, steps=1)
+ evaluation = estimator.evaluate(
+ input_fn=eval_input_fn, steps=1)
+ self.assertAllEqual(evaluation["loss"], evaluation["average_loss"])
+ self.assertAllEqual([], evaluation["loss"].shape)
if __name__ == "__main__":
test.main()
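The test template above threads exogenous features through evaluation, prediction, and filtering. Outside the test harness, the user-facing prediction pattern looks roughly like this sketch, assuming `estimator` and an `evaluation` dictionary from `estimator.evaluate()` already exist, with one exogenous feature named "exogenous":

```python
import numpy

from tensorflow.contrib.timeseries.python.timeseries import input_pipeline

# Ten steps of the exogenous series, shaped [batch, steps, 1], must be
# provided for the steps being predicted.
exogenous_values_ten_steps = {
    "exogenous": numpy.arange(10, dtype=numpy.float32)[None, :, None]}
predict_input_fn = input_pipeline.predict_continuation_input_fn(
    evaluation=evaluation,
    exogenous_features=exogenous_values_ten_steps,
    steps=10)
predictions = next(estimator.predict(input_fn=predict_input_fn))
```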
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
index a28a5872b8..1f9f9b7aa6 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -19,24 +19,23 @@ from __future__ import print_function
import re
-from tensorflow.python.training import training_util
-from tensorflow.contrib.layers.python.layers import optimizers
-
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
-
from tensorflow.python.estimator import estimator_lib
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.export import export_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.util import nest
from tensorflow.python.summary import summary
+from tensorflow.python.training import training_util
+from tensorflow.python.util import nest
class _NoStatePredictOutput(export_lib.PredictOutput):
@@ -102,12 +101,9 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
use_resource=True):
model_outputs = self.create_loss(features, mode)
- train_op = optimizers.optimize_loss(
+ train_op = self.optimizer.minimize(
model_outputs.loss,
- global_step=training_util.get_global_step(),
- optimizer=self.optimizer,
- # Learning rate is set in the Optimizer object
- learning_rate=None)
+ global_step=training_util.get_global_step())
return estimator_lib.EstimatorSpec(
loss=model_outputs.loss,
mode=mode,
@@ -128,11 +124,14 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
metrics[feature_keys.FilteringResults.STATE_TUPLE] = (
_identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE,
model_outputs.end_state))
+ metrics[metric_keys.MetricKeys.LOSS_MEAN] = metrics_impl.mean(
+ model_outputs.loss, name="average_loss")
return estimator_lib.EstimatorSpec(
loss=model_outputs.loss,
mode=mode,
eval_metric_ops=metrics,
- predictions={})
+        # Returning predictions is needed for custom metrics.
+ predictions=model_outputs.predictions)
def _predict_ops(self, features):
"""Add ops for prediction to the graph."""
@@ -185,7 +184,7 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
return math_ops.cast(value, self.model.dtype)
if name == feature_keys.PredictionFeatures.STATE_TUPLE:
return value # Correct dtypes are model-dependent
- return ops.convert_to_tensor(value)
+ return sparse_tensor.convert_to_tensor_or_sparse_tensor(value)
def _gather_state(self, features):
"""Returns `features` with state packed, indicates if packing was done."""
@@ -207,15 +206,38 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
flat_sequence=[tensor for _, _, tensor in numbered_state])
return features, True
+ def _check_predict_features(self, features):
+ """Raises errors if features are not suitable for prediction."""
+ if feature_keys.PredictionFeatures.TIMES not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.TIMES))
+ if feature_keys.PredictionFeatures.STATE_TUPLE not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.STATE_TUPLE))
+ times_feature = features[feature_keys.PredictionFeatures.TIMES]
+ if not times_feature.get_shape().is_compatible_with([None, None]):
+ raise ValueError(
+ ("Expected shape (batch dimension, window size) for feature '{}' "
+ "(got shape {})").format(feature_keys.PredictionFeatures.TIMES,
+ times_feature.get_shape()))
+ _check_feature_shapes_compatible_with(
+ features=features,
+ compatible_with_name=feature_keys.PredictionFeatures.TIMES,
+ compatible_with_value=times_feature,
+ ignore=set([
+ # Model-dependent shapes
+ feature_keys.PredictionFeatures.STATE_TUPLE
+ ]))
+
def create_estimator_spec(self, features, mode, labels=None):
"""Performs basic error checking and returns an EstimatorSpec."""
with ops.name_scope(self._name, "head"):
- if labels:
+ if labels is not None and labels != {}: # for better error messages.
raise ValueError(
- "The model received a `labels` dictionary, which is "
- "not supported. Pass '{}' and '{}' as "
- "features.".format(feature_keys.TrainEvalFeatures.TIMES,
- feature_keys.TrainEvalFeatures.VALUES))
+            "The model received a `labels` argument, which is not supported. "
+ "Pass '{}' and '{}' as features.".format(
+ feature_keys.TrainEvalFeatures.TIMES,
+ feature_keys.TrainEvalFeatures.VALUES))
del labels
features = {
name: self._convert_feature_to_tensor(name=name, value=value)
@@ -235,7 +257,7 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
mode == estimator_lib.ModeKeys.EVAL):
_check_train_eval_features(features, self.model)
elif mode == estimator_lib.ModeKeys.PREDICT:
- _check_predict_features(features)
+ self._check_predict_features(features)
else:
raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode))
@@ -272,6 +294,44 @@ class OneShotPredictionHead(TimeSeriesRegressionHead):
each time predictions are requested when using this head.
"""
+ def _check_predict_features(self, features):
+ """Raises errors if features are not suitable for one-shot prediction."""
+ if feature_keys.PredictionFeatures.TIMES not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.TIMES))
+ if feature_keys.TrainEvalFeatures.VALUES not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.TrainEvalFeatures.VALUES))
+ if feature_keys.PredictionFeatures.STATE_TUPLE not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.STATE_TUPLE))
+ times_feature = features[feature_keys.PredictionFeatures.TIMES]
+ if not times_feature.get_shape().is_compatible_with([None, None]):
+ raise ValueError(
+ ("Expected shape (batch dimension, window size) for feature '{}' "
+ "(got shape {})").format(feature_keys.PredictionFeatures.TIMES,
+ times_feature.get_shape()))
+ _check_feature_shapes_compatible_with(
+ features=features,
+ compatible_with_name=feature_keys.PredictionFeatures.TIMES,
+ compatible_with_value=times_feature,
+ ignore=set([
+ # Model-dependent shapes
+ feature_keys.PredictionFeatures.STATE_TUPLE,
+ # One shot prediction head relies on values being shorter than
+            # times. Even though the end goal is prediction, values are needed
+            # for the filtering phase.
+ feature_keys.TrainEvalFeatures.VALUES,
+ ]))
+
+ def _evaluate_ops(self, features):
+ """Add ops for evaluation (aka filtering) to the graph."""
+ spec = super(OneShotPredictionHead, self)._evaluate_ops(features)
+    # No state is fed to OneShotPredictionHead, so we don't return it; the
+    # state tuple can cause issues for downstream infrastructure.
+ del spec.eval_metric_ops[feature_keys.State.STATE_TUPLE]
+ return spec
+
def _serving_ops(self, features):
"""Add ops for serving to the graph."""
with variable_scope.variable_scope("model", use_resource=True):
@@ -338,29 +398,6 @@ def _check_feature_shapes_compatible_with(features,
times_shape=compatible_with_value.get_shape()))
-def _check_predict_features(features):
- """Raises errors if features are not suitable for prediction."""
- if feature_keys.PredictionFeatures.TIMES not in features:
- raise ValueError("Expected a '{}' feature for prediction.".format(
- feature_keys.PredictionFeatures.TIMES))
- if feature_keys.PredictionFeatures.STATE_TUPLE not in features:
- raise ValueError("Expected a '{}' feature for prediction.".format(
- feature_keys.PredictionFeatures.STATE_TUPLE))
- times_feature = features[feature_keys.PredictionFeatures.TIMES]
- if not times_feature.get_shape().is_compatible_with([None, None]):
- raise ValueError(
- ("Expected shape (batch dimension, window size) for feature '{}' "
- "(got shape {})").format(feature_keys.PredictionFeatures.TIMES,
- times_feature.get_shape()))
- _check_feature_shapes_compatible_with(
- features=features,
- compatible_with_name=feature_keys.PredictionFeatures.TIMES,
- compatible_with_value=times_feature,
- ignore=set([
- feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes
- ]))
-
-
def _check_train_eval_features(features, model):
"""Raise errors if features are not suitable for training/evaluation."""
if feature_keys.TrainEvalFeatures.TIMES not in features:
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index c606db76a6..647455ae42 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -18,16 +18,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+import os
+
+from absl.testing import parameterized
import numpy
import six
+from tensorflow.contrib.estimator.python.estimator import extenders
from tensorflow.contrib.timeseries.examples import lstm as lstm_example
+from tensorflow.contrib.timeseries.python.timeseries import ar_model
from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib
from tensorflow.contrib.timeseries.python.timeseries import input_pipeline
from tensorflow.contrib.timeseries.python.timeseries import model
from tensorflow.contrib.timeseries.python.timeseries import state_management
+from tensorflow.core.example import example_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.estimator import estimator_lib
@@ -35,6 +42,7 @@ from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import variables
@@ -53,9 +61,12 @@ class HeadTest(test.TestCase):
model_fn = _stub_model_fn()
for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL,
estimator_lib.ModeKeys.PREDICT]:
- with self.assertRaisesRegexp(ValueError, "labels"):
+ with self.assertRaisesRegexp(ValueError, "received a `labels`"):
model_fn(features={}, labels={"a": "b"}, mode=mode)
+ with self.assertRaisesRegexp(ValueError, "received a `labels`"):
+ model_fn(features={}, labels=array_ops.zeros([]), mode=mode)
+
def test_unknown_mode(self):
model_fn = _stub_model_fn()
with self.assertRaisesRegexp(ValueError, "Unknown mode 'Not a mode'"):
@@ -111,7 +122,7 @@ class EvaluationMetricsTests(test.TestCase):
metric[1] for metric in outputs.eval_metric_ops.values()]
loss_mean, loss_update = metrics.mean(outputs.loss)
metric_update_ops.append(loss_update)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(sess, coord=coordinator)
variables.local_variables_initializer().run()
@@ -128,6 +139,45 @@ class EvaluationMetricsTests(test.TestCase):
coordinator.request_stop()
coordinator.join()
+ def test_custom_metrics(self):
+ """Tests that the custom metrics can be applied to the estimator."""
+ model_dir = self.get_temp_dir()
+ estimator = ts_estimators.TimeSeriesRegressor(
+ model=lstm_example._LSTMModel(num_features=1, num_units=4),
+ optimizer=adam.AdamOptimizer(0.001),
+ config=estimator_lib.RunConfig(tf_random_seed=4),
+ model_dir=model_dir)
+
+ def input_fn():
+ return {
+ feature_keys.TrainEvalFeatures.TIMES: [[1, 2, 3], [7, 8, 9]],
+ feature_keys.TrainEvalFeatures.VALUES:
+ numpy.array([[[0.], [1.], [0.]], [[2.], [3.], [2.]]])
+ }
+
+ def metrics_fn(predictions, features):
+      # Check that the inputs are properly passed.
+ predict = predictions["mean"]
+ target = features[feature_keys.TrainEvalFeatures.VALUES][:, -1, 0]
+ return {
+ "plain_boring_metric386":
+ (math_ops.reduce_mean(math_ops.abs(predict - target)),
+ control_flow_ops.no_op()),
+ "fun_metric101": (math_ops.reduce_sum(predict + target),
+ control_flow_ops.no_op()),
+ }
+
+ # Evaluation without training is enough for testing custom metrics.
+ estimator = extenders.add_metrics(estimator, metrics_fn)
+ evaluation = estimator.evaluate(input_fn, steps=1)
+ self.assertIn("plain_boring_metric386", evaluation)
+ self.assertIn("fun_metric101", evaluation)
+ self.assertIn("average_loss", evaluation)
+ # The values are deterministic because of fixed tf_random_seed.
+    # However, if they become flaky, remove such exact comparisons.
+ self.assertAllClose(evaluation["plain_boring_metric386"], 1.130380)
+ self.assertAllClose(evaluation["fun_metric101"], 10.435442)
+
class _StubModel(object):
num_features = 3
@@ -274,10 +324,56 @@ class PredictFeatureCheckingTests(test.TestCase):
mode=estimator_lib.ModeKeys.PREDICT)
-class OneShotTests(test.TestCase):
-
- def test_one_shot_prediction_head_export(self):
- model_dir = self.get_temp_dir()
+def _custom_time_series_regressor(
+ model_dir, head_type, exogenous_feature_columns):
+ return ts_estimators.TimeSeriesRegressor(
+ model=lstm_example._LSTMModel(
+ num_features=5, num_units=128,
+ exogenous_feature_columns=exogenous_feature_columns),
+ optimizer=adam.AdamOptimizer(0.001),
+ config=estimator_lib.RunConfig(tf_random_seed=4),
+ state_manager=state_management.ChainingStateManager(),
+ head_type=head_type,
+ model_dir=model_dir)
+
+
+def _structural_ensemble_regressor(
+ model_dir, head_type, exogenous_feature_columns):
+ return ts_estimators.StructuralEnsembleRegressor(
+ periodicities=None,
+ num_features=5,
+ exogenous_feature_columns=exogenous_feature_columns,
+ head_type=head_type,
+ model_dir=model_dir)
+
+
+def _ar_lstm_regressor(
+ model_dir, head_type, exogenous_feature_columns):
+ return ts_estimators.TimeSeriesRegressor(
+ model=ar_model.ARModel(
+ periodicities=10, input_window_size=10, output_window_size=6,
+ num_features=5,
+ exogenous_feature_columns=exogenous_feature_columns,
+ prediction_model_factory=functools.partial(
+ ar_model.LSTMPredictionModel,
+ num_units=10)),
+ head_type=head_type,
+ model_dir=model_dir)
+
+
+class OneShotTests(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {"testcase_name": "ar_lstm_regressor",
+ "estimator_factory": _ar_lstm_regressor},
+ {"testcase_name": "custom_time_series_regressor",
+ "estimator_factory": _custom_time_series_regressor},
+ {"testcase_name": "structural_ensemble_regressor",
+ "estimator_factory": _structural_ensemble_regressor})
+ def test_one_shot_prediction_head_export(self, estimator_factory):
+ def _new_temp_dir():
+ return os.path.join(test.get_temp_dir(), str(ops.uid()))
+ model_dir = _new_temp_dir()
categorical_column = feature_column.categorical_column_with_hash_bucket(
key="categorical_exogenous_feature", hash_bucket_size=16)
exogenous_feature_columns = [
@@ -285,15 +381,10 @@ class OneShotTests(test.TestCase):
"2d_exogenous_feature", shape=(2,)),
feature_column.embedding_column(
categorical_column=categorical_column, dimension=10)]
- estimator = ts_estimators.TimeSeriesRegressor(
- model=lstm_example._LSTMModel(
- num_features=5, num_units=128,
- exogenous_feature_columns=exogenous_feature_columns),
- optimizer=adam.AdamOptimizer(0.001),
- config=estimator_lib.RunConfig(tf_random_seed=4),
- state_manager=state_management.ChainingStateManager(),
- head_type=ts_head_lib.OneShotPredictionHead,
- model_dir=model_dir)
+ estimator = estimator_factory(
+ model_dir=model_dir,
+ exogenous_feature_columns=exogenous_feature_columns,
+ head_type=ts_head_lib.OneShotPredictionHead)
train_features = {
feature_keys.TrainEvalFeatures.TIMES: numpy.arange(
20, dtype=numpy.int64),
@@ -307,8 +398,11 @@ class OneShotTests(test.TestCase):
input_pipeline.NumpyReader(train_features), shuffle_seed=2,
num_threads=1, batch_size=16, window_size=16)
estimator.train(input_fn=train_input_fn, steps=5)
+ result = estimator.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertIn("average_loss", result)
+ self.assertNotIn(feature_keys.State.STATE_TUPLE, result)
input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
- export_location = estimator.export_savedmodel(self.get_temp_dir(),
+ export_location = estimator.export_savedmodel(_new_temp_dir(),
input_receiver_fn)
graph = ops.Graph()
with graph.as_default():
@@ -342,7 +436,42 @@ class OneShotTests(test.TestCase):
for output_key, output_value
in predict_signature.outputs.items()}
output = session.run(fetches, feed_dict=feeds)
- self.assertAllEqual((2, 15, 5), output["mean"].shape)
+ self.assertEqual((2, 15, 5), output["mean"].shape)
+ # Build a parsing input function, then make a tf.Example for it to parse.
+ export_location = estimator.export_savedmodel(
+ _new_temp_dir(),
+ estimator.build_one_shot_parsing_serving_input_receiver_fn(
+ filtering_length=20, prediction_length=15))
+ graph = ops.Graph()
+ with graph.as_default():
+ with session_lib.Session() as session:
+ example = example_pb2.Example()
+ times = example.features.feature[feature_keys.TrainEvalFeatures.TIMES]
+ values = example.features.feature[feature_keys.TrainEvalFeatures.VALUES]
+ times.int64_list.value.extend(range(35))
+ for i in range(20):
+ values.float_list.value.extend(
+ [float(i) * 2. + feature_number
+ for feature_number in range(5)])
+ real_feature = example.features.feature["2d_exogenous_feature"]
+        categorical_feature = example.features.feature[
+ "categorical_exogenous_feature"]
+ for i in range(35):
+ real_feature.float_list.value.extend([1, 1])
+          categorical_feature.bytes_list.value.append(b"strkey")
+ # Serialize the tf.Example for feeding to the Session
+ examples = [example.SerializeToString()] * 2
+ signatures = loader.load(
+ session, [tag_constants.SERVING], export_location)
+ predict_signature = signatures.signature_def[
+ feature_keys.SavedModelLabels.PREDICT]
+ ((_, input_value),) = predict_signature.inputs.items()
+ feeds = {graph.as_graph_element(input_value.name): examples}
+ fetches = {output_key: graph.as_graph_element(output_value.name)
+ for output_key, output_value
+ in predict_signature.outputs.items()}
+ output = session.run(fetches, feed_dict=feeds)
+ self.assertEqual((2, 15, 5), output["mean"].shape)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py
index 703537abf0..f92148b788 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py
@@ -88,7 +88,7 @@ class RandomWindowInputFnTests(test.TestCase):
window_size=window_size, batch_size=batch_size)
result, _ = input_fn()
init_op = variables.local_variables_initializer()
- with self.test_session() as session:
+ with self.cached_session() as session:
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
session.run(init_op)
@@ -261,7 +261,7 @@ class WholeDatasetInputFnTests(test.TestCase):
def _whole_dataset_input_fn_test_template(
self, time_series_reader, num_features, num_samples):
result, _ = input_pipeline.WholeDatasetInputFn(time_series_reader)()
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(variables.local_variables_initializer())
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -340,7 +340,7 @@ class AllWindowInputFnTests(test.TestCase):
window_size=window_size)
features, _ = input_fn()
init_op = variables.local_variables_initializer()
- with self.test_session() as session:
+ with self.cached_session() as session:
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
session.run(init_op)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
index 26793c80bf..03da2b82e5 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
@@ -60,7 +60,7 @@ def clip_covariance(
# TODO(allenl): Smarter scaling here so that correlations are preserved when
# fiddling with diagonal elements.
diagonal = array_ops.matrix_diag_part(covariance_matrix)
- maximum = math_ops.reduce_max(diagonal, axis=-1, keep_dims=True)
+ maximum = math_ops.reduce_max(diagonal, axis=-1, keepdims=True)
new_diagonal = gen_math_ops.maximum(
diagonal, maximum / maximum_variance_ratio)
return array_ops.matrix_set_diag(
@@ -896,8 +896,8 @@ class InputStatisticsFromMiniBatch(object):
statistics.total_observation_count,
math_ops.cast(
gen_math_ops.round(
- math_ops.cast(auxiliary_variables.max_time_seen -
- statistics.start_time + 1, self._dtype) /
+ math_ops.cast(max_time_seen_assign -
+ start_time_update + 1, self._dtype) /
inter_observation_duration_estimate), dtypes.int64))
per_chunk_stat_updates = control_flow_ops.group(
overall_feature_mean_update, overall_feature_var_update,
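The statistics fix above computes the chunk count from `max_time_seen_assign` and `start_time_update`, the tensors returned by the assignment ops, rather than re-reading the underlying variables and risking a stale value. A minimal illustration of the pattern, using a hypothetical variable (not from this module):

```python
import tensorflow as tf

counter = tf.Variable(0, dtype=tf.int64)
# tf.assign returns a tensor carrying the post-assignment value; depending on
# it (rather than on `counter` directly) guarantees downstream ops observe
# the update.
counter_update = tf.assign(counter, counter + 5)
doubled = counter_update * 2

with tf.Session() as session:
  session.run(tf.global_variables_initializer())
  print(session.run(doubled))  # 10
```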
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
index b9f8620fd8..c0de42b15b 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
@@ -55,7 +55,7 @@ class MathUtilsTest(test.TestCase):
running_sum = running_sum + current_contribution
# pylint: enable=g-no-augmented-assignment
transition_power = numpy.dot(transition, transition_power)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(result,
math_utils.power_sums_tensor(
array_size, transition, addition).eval())
@@ -66,7 +66,7 @@ class MathUtilsTest(test.TestCase):
result = []
for i in range(powers.shape[0]):
result.append(numpy.linalg.matrix_power(matrix, powers[i]))
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(result,
math_utils.matrix_to_powers(matrix, powers).eval(),
rtol=1e-5,
@@ -78,7 +78,7 @@ class MathUtilsTest(test.TestCase):
result = []
for i in range(batch.shape[0]):
result.append(numpy.linalg.matrix_power(batch[i], powers[i]))
- with self.test_session():
+ with self.cached_session():
# TODO(allenl): Numerical errors seem to be creeping in. Maybe it can be
# made slightly more stable?
self.assertAllClose(result,
@@ -91,7 +91,7 @@ class MathUtilsTest(test.TestCase):
left_transpose = numpy.transpose(left, [0, 2, 1])
right = numpy.random.normal(size=[2, 3]).astype(numpy.float32)
expected_result = numpy.dot(left, right)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected_result,
math_utils.batch_times_matrix(
left, right).eval())
@@ -114,7 +114,7 @@ class MathUtilsTest(test.TestCase):
right_transpose = numpy.transpose(right, [0, 2, 1])
expected_result = numpy.transpose(numpy.dot(right_transpose, left.T),
[0, 2, 1])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected_result,
math_utils.matrix_times_batch(
left, right).eval())
@@ -132,7 +132,7 @@ class MathUtilsTest(test.TestCase):
adj_x=True, adj_y=True).eval())
def test_make_diagonal_undefined_shapes(self):
- with self.test_session():
+ with self.cached_session():
completely_undefined = array_ops.placeholder(dtype=dtypes.float32)
partly_undefined = array_ops.placeholder(
shape=[None, None], dtype=dtypes.float32)
@@ -152,7 +152,7 @@ class MathUtilsTest(test.TestCase):
[5., 6.]]}))
def test_make_diagonal_mostly_defined_shapes(self):
- with self.test_session():
+ with self.cached_session():
mostly_defined = array_ops.placeholder(
shape=[None, 2], dtype=dtypes.float32)
blocked = math_utils.block_diagonal([[[2.]],
@@ -192,7 +192,7 @@ class TestMakeToeplitzMatrix(test.TestCase):
def _test_make_toeplitz_matrix(self, inputs, output_expected):
output_tf = math_utils.make_toeplitz_matrix(inputs)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_tf_np = sess.run(output_tf)
self.assertAllClose(output_tf_np, output_expected)
@@ -201,13 +201,13 @@ class TestMakeCovarianceMatrix(test.TestCase):
def test_zero_size_matrix(self):
raw = numpy.zeros([0, 0])
- with self.test_session():
+ with self.cached_session():
constructed = math_utils.sign_magnitude_positive_definite(raw=raw).eval()
self.assertEqual((0, 0), constructed.shape)
def test_sign_magnitude_positive_definite(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
matrix_tensor = math_utils.sign_magnitude_positive_definite(
raw=constant_op.constant([[-1., -2.], [3., 4.]], dtype=dtype),
off_diagonal_scale=constant_op.constant(-1., dtype=dtype),
@@ -230,7 +230,8 @@ class TestLookupTable(test.TestCase):
name="test_lookup")
def stack_tensor(base_tensor):
return array_ops.stack([base_tensor + 1, base_tensor + 2])
- with self.test_session() as session:
+
+ with self.cached_session() as session:
((float_output, double_output), int_output) = session.run(
hash_table.lookup([2, 1, 0]))
def expected_output_before_insert(base_tensor):
@@ -290,7 +291,7 @@ class InputStatisticsTests(test.TestCase):
time_series_reader=input_pipeline.NumpyReader(features))
statistics = stat_object.initialize_graph(
features=input_fn()[0])
- with self.test_session(graph=graph) as session:
+ with self.session(graph=graph) as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
index cfd31cc70d..a049dbe773 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
@@ -29,7 +29,7 @@ class ModelUtilsTest(test.TestCase):
def test_parameter_switching(self):
parameter = array_ops.constant(5)
overridden_parameter = array_ops.constant(3)
- with self.test_session():
+ with self.cached_session():
getter = model_utils.parameter_switch({overridden_parameter: 4})
self.assertEqual(5, getter(parameter))
self.assertEqual(4, getter(overridden_parameter))
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py
index d5dce30fda..42ba6e1c25 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py
@@ -78,7 +78,7 @@ class StubTimeSeriesModel(model.TimeSeriesModel):
batch_end_values = array_ops.squeeze(
array_ops.slice(values, [0, array_ops.shape(times)[1] - 1, 0],
[-1, 1, -1]),
- squeeze_dims=[1, 2])
+ axis=[1, 2])
# A pretty odd but easy to think about loss: L1 loss on the batch end
# values.
loss = math_ops.reduce_sum(
@@ -127,7 +127,7 @@ class ChainingStateManagerTest(test.TestCase):
chainer.initialize_graph(model=stub_model)
model_outputs = chainer.define_loss(
model=stub_model, features=features, mode=estimator_lib.ModeKeys.TRAIN)
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -178,7 +178,7 @@ class ChainingStateManagerTest(test.TestCase):
result_model_outputs = chainer.define_loss(
model=stub_model, features=result_input_fn()[0],
mode=estimator_lib.ModeKeys.TRAIN)
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -221,7 +221,7 @@ class ChainingStateManagerTest(test.TestCase):
chainer.initialize_graph(model=stub_model)
model_outputs = chainer.define_loss(
model=stub_model, features=features, mode=estimator_lib.ModeKeys.TRAIN)
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
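
The unchanged coordinator/queue-runner scaffolding around these hunks is the standard TF 1.x input-pipeline test pattern: start the runner threads, drive the graph, then stop and join. A self-contained sketch with a hypothetical toy queue (not this commit's code):

import tensorflow as tf

class QueueRunnerPatternTest(tf.test.TestCase):

  def test_dequeue_with_runner(self):
    # A runner thread repeatedly feeds the queue in the background.
    queue = tf.FIFOQueue(capacity=100, dtypes=[tf.float32])
    enqueue_op = queue.enqueue_many([[1., 2., 3.]])
    tf.train.add_queue_runner(tf.train.QueueRunner(queue, [enqueue_op]))
    with self.cached_session() as session:
      coordinator = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(session, coord=coordinator)
      self.assertEqual(1., session.run(queue.dequeue()))
      coordinator.request_stop()
      coordinator.join(threads)
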
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD
index 5d33e23a42..3c07a74ed8 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD
@@ -176,8 +176,9 @@ py_library(
py_test(
name = "structural_ensemble_test",
- timeout = "long", # Moderate but for asan/tsan timeouts
+ timeout = "long", # Moderate but for asan/tsan/msan timeouts
srcs = ["structural_ensemble_test.py"],
+ shard_count = 4,
srcs_version = "PY2AND3",
deps = [
":state_space_model",
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py
index 53d7340e85..a77c507d9b 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py
@@ -61,7 +61,7 @@ class FilteringStepPostprocessorTest(test.TestCase):
expected_state = [[[80.], [20.]],
[1., 6.],
[-1, -2]]
- with self.test_session():
+ with self.cached_session():
for interpolated, expected in zip(interpolated_state, expected_state):
self.assertAllClose(expected, interpolated.eval())
self.assertGreater(0., updated_outputs["anomaly_score"][0].eval())
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py
index 1fcd3e391b..a614386121 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py
@@ -170,7 +170,7 @@ class KalmanFilter(object):
math_ops.matmul(
transition_matrices,
prior_state[..., None]),
- squeeze_dims=[-1])
+ axis=[-1])
return advanced_state
def predict_state_var(
@@ -254,7 +254,7 @@ class KalmanFilter(object):
kalman_gain_transposed,
array_ops.expand_dims(residual, -1),
adjoint_a=True),
- squeeze_dims=[-1])
+ axis=[-1])
gain_obs = math_ops.matmul(
kalman_gain_transposed, observation_model, adjoint_a=True)
identity_extradim = linalg_ops.eye(
@@ -332,7 +332,7 @@ class KalmanFilter(object):
array_ops.expand_dims(state_mean, 1),
observation_model,
adjoint_b=True),
- squeeze_dims=[1])
+ axis=[1])
observed_var = math_ops.matmul(
math_ops.matmul(observation_model, state_var),
observation_model,
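
The squeeze_dims -> axis rewrites here (and in state_management_test.py above) track the deprecation of tf.squeeze's squeeze_dims argument; axis names the same size-1 dimensions to drop, so the change is behavior-preserving. A small sketch of the equivalence:

import tensorflow as tf

x = tf.zeros([3, 1, 4, 1])
old = tf.squeeze(x, squeeze_dims=[1, 3])  # deprecated spelling, logs a warning
new = tf.squeeze(x, axis=[1, 3])          # current spelling
print(old.shape, new.shape)               # both (3, 4)
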
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py
index 57f29f3c7f..f636126a33 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py
@@ -98,7 +98,7 @@ class MultivariateTests(test.TestCase):
observation_model=observation_model,
predicted_observations=(observed_mean, observed_var),
observation_noise=observation_noise_covariance)
- with self.test_session() as session:
+ with self.cached_session() as session:
evaled_state = numpy.array([[1., 1., 1., 1.]])
evaled_state_var = numpy.eye(4)[None]
for i in range(500):
@@ -136,7 +136,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
def test_observed_from_state(self):
"""Compare observation mean and noise to hand-computed values."""
- with self.test_session():
+ with self.cached_session():
state = constant_op.constant([[2., 1.]])
state_var = constant_op.constant([[[4., 0.], [0., 3.]]])
observed_mean, observed_var = self.kalman_filter.observed_from_state(
@@ -171,7 +171,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
observation_model=observation_model,
predicted_observations=predicted_observations,
observation_noise=observation_noise))
- with self.test_session() as session:
+ with self.cached_session() as session:
evaled_state, evaled_state_var = session.run([state, state_var])
for _ in range(300):
evaled_state, evaled_state_var = session.run(
@@ -231,7 +231,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
def test_predict_state_mean(self):
"""Compare state mean transitions with simple hand-computed values."""
- with self.test_session():
+ with self.cached_session():
state = constant_op.constant([[4., 2.]])
state = self.kalman_filter.predict_state_mean(
state, self.transition_fn([1]))
@@ -245,7 +245,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
def test_predict_state_var(self):
"""Compare a variance transition with simple hand-computed values."""
- with self.test_session():
+ with self.cached_session():
state_var = constant_op.constant([[[1., 0.], [0., 2.]]])
state_var = self.kalman_filter.predict_state_var(
state_var, self.transition_fn([1]), self.power_sum_fn([1]))
@@ -259,7 +259,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
Tests that correct values have high probability and incorrect values
have low probability when there is low uncertainty.
"""
- with self.test_session():
+ with self.cached_session():
state = constant_op.constant([[4., 2.]])
state_var = constant_op.constant([[[0.0001, 0.], [0., 0.0001]]])
observation = constant_op.constant([[
@@ -289,7 +289,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
self.assertGreater(first_log_prob.eval()[0], numpy.log(0.99))
def test_predict_n_ahead_mean(self):
- with self.test_session():
+ with self.cached_session():
original_state = constant_op.constant([[4., 2.]])
n = 5
iterative_state = original_state
@@ -304,7 +304,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
self.transition_fn([1]))
def test_predict_n_ahead_var(self):
- with self.test_session():
+ with self.cached_session():
original_var = constant_op.constant([[[2., 3.], [4., 5.]]])
n = 5
iterative_var = original_var
@@ -330,7 +330,7 @@ class KalmanFilterBatchTest(test.TestCase):
Tests that correct values have high probability and incorrect values
have low probability when there is low uncertainty.
"""
- with self.test_session():
+ with self.cached_session():
state = constant_op.constant([[4., 2.], [5., 3.], [6., 4.]])
state_var = constant_op.constant(3 * [[[0.0001, 0.], [0., 0.0001]]])
observation = constant_op.constant([
@@ -378,7 +378,7 @@ class KalmanFilterBatchTest(test.TestCase):
self.assertLess(third_log_prob.sum(), numpy.log(0.01))
def test_predict_n_ahead_mean(self):
- with self.test_session():
+ with self.cached_session():
kf = kalman_filter.KalmanFilter()
transition_fn, _ = _powers_and_sums_from_transition_matrix(
state_transition=STATE_TRANSITION,
@@ -396,7 +396,7 @@ class KalmanFilterBatchTest(test.TestCase):
self.assertAllClose(state2.eval()[2], batch_eval[2])
def test_predict_n_ahead_var(self):
- with self.test_session():
+ with self.cached_session():
kf = kalman_filter.KalmanFilter()
transition_fn, power_sum_fn = _powers_and_sums_from_transition_matrix(
state_transition=STATE_TRANSITION,
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
index 1fb4a3c121..80126ac786 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
@@ -96,7 +96,7 @@ class ConstructionTests(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
@@ -114,7 +114,7 @@ class ConstructionTests(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
@@ -144,7 +144,7 @@ class GapTests(test.TestCase):
state=math_utils.replicate_state(
start_state=random_model.get_start_state(),
batch_size=array_ops.shape(times)[0]))
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -190,13 +190,13 @@ class StateSpaceEquivalenceTests(test.TestCase):
estimator.build_raw_serving_input_receiver_fn())
with ops.Graph().as_default() as graph:
random_model.initialize_graph()
- with self.test_session(graph=graph) as session:
+ with self.session(graph=graph) as session:
variables.global_variables_initializer().run()
evaled_start_state = session.run(random_model.get_start_state())
evaled_start_state = [
state_element[None, ...] for state_element in evaled_start_state]
with ops.Graph().as_default() as graph:
- with self.test_session(graph=graph) as session:
+ with self.session(graph=graph) as session:
signatures = loader.load(
session, [tag_constants.SERVING], export_location)
first_split_filtering = saved_model_utils.filter_continuation(
@@ -250,7 +250,7 @@ class StateSpaceEquivalenceTests(test.TestCase):
self.assertAllClose(combined_value, split_predict[prediction_key])
def _equivalent_to_single_model_test_template(self, model_generator):
- with self.test_session() as session:
+ with self.cached_session() as session:
random_model = RandomStateSpaceModel(
state_dimension=5,
state_noise_dimension=4,
@@ -374,7 +374,7 @@ class PredictionTests(test.TestCase):
math_utils.replicate_state(
start_state=random_model.get_start_state(), batch_size=1)
})
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
predicted_mean = prediction_dict["mean"].eval()
predicted_covariance = prediction_dict["covariance"].eval()
@@ -404,7 +404,7 @@ class PredictionTests(test.TestCase):
feature_keys.PredictionFeatures.TIMES: [[5, 7, 8]],
feature_keys.PredictionFeatures.STATE_TUPLE: model_outputs.end_state
})
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
predicted_mean = predictions["mean"].eval()
predicted_covariance = predictions["covariance"].eval()
@@ -428,7 +428,7 @@ class ExogenousTests(test.TestCase):
state=[
array_ops.ones(shape=[1, 5]), original_covariance[None], [0]
])
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
evaled_new_covariance, evaled_original_covariance = session.run(
[new_covariance[0], original_covariance])
@@ -454,7 +454,7 @@ class ExogenousTests(test.TestCase):
-array_ops.ones(shape=[1, 5], dtype=dtype),
original_covariance[None], [0]
])
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
evaled_new_covariance, evaled_original_covariance = session.run(
[new_covariance[0], original_covariance])
@@ -519,7 +519,7 @@ class PosteriorTests(test.TestCase):
model=stub_model, data=data, true_parameters=true_params)
def test_exact_posterior_recovery_no_transition_noise(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
stub_model, data, true_params = self._get_single_model()
input_fn = input_pipeline.WholeDatasetInputFn(
input_pipeline.NumpyReader(data))
@@ -559,7 +559,7 @@ class PosteriorTests(test.TestCase):
posterior_times)
def test_chained_exact_posterior_recovery_no_transition_noise(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
stub_model, data, true_params = self._get_single_model()
chunk_size = 10
input_fn = test_utils.AllWindowInputFn(
@@ -748,7 +748,7 @@ class MultivariateTests(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py
index 84885d5c9a..e8875f4eb9 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py
@@ -46,7 +46,7 @@ class MakeModelTest(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
@@ -65,7 +65,7 @@ class MakeModelTest(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
@@ -85,7 +85,7 @@ class MakeModelTest(test.TestCase):
TrainEvalFeatures.VALUES: constant_op.constant([[[1.], [2.]]])},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
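
The construction tests touched throughout this diff share one shape: build the TRAIN-mode loss, run the variable initializer, and evaluate the loss once to confirm the graph wires up. A reduced sketch of that shape with a stand-in quadratic loss, assuming TF 1.x (not this commit's models):

import tensorflow as tf

class ConstructionPatternTest(tf.test.TestCase):

  def test_loss_evaluates(self):
    weight = tf.Variable([1.0])
    loss = tf.reduce_sum(tf.square(weight - tf.constant([3.0])))
    initializer = tf.global_variables_initializer()
    with self.cached_session() as sess:
      sess.run([initializer])
      self.assertAllClose(4.0, loss.eval())  # (1 - 3)^2 == 4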