aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/timeseries
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-04-02 17:21:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-02 17:24:23 -0700
commit03613455c1c1c3957aedc4edcedd96a21bf9a514 (patch)
treee1b2d722a48dda501dc59062f112461c492ab6cb /tensorflow/contrib/timeseries
parent7e9113ab912caff9ad15195b15771ff20bde6080 (diff)
TFTS: Add a OneShotPredictionHead with no model state in its serving signature.
PiperOrigin-RevId: 191373516
Diffstat (limited to 'tensorflow/contrib/timeseries')
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/BUILD15
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators.py8
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head.py84
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head_test.py86
4 files changed, 166 insertions, 27 deletions
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index 55a25e39fe..86022f46ce 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -88,10 +88,14 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python:tensor_util",
"//tensorflow/python:training",
+ "//tensorflow/python:util",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/estimator:export",
+ "//tensorflow/python/feature_column",
],
)
@@ -132,7 +136,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":feature_keys",
- "//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
@@ -141,6 +144,7 @@ py_library(
"//tensorflow/python:math_ops",
"//tensorflow/python:state_ops",
"//tensorflow/python:summary",
+ "//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/estimator:estimator_py",
@@ -160,19 +164,28 @@ py_test(
"no_pip_gpu", # b/63391119
],
deps = [
+ ":estimators",
":feature_keys",
":head",
+ ":input_pipeline",
":model",
":state_management",
+ "//tensorflow/contrib/timeseries/examples:lstm",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:metrics",
+ "//tensorflow/python:session",
"//tensorflow/python:training",
"//tensorflow/python:variables",
"//tensorflow/python/estimator:estimator_py",
+ "//tensorflow/python/feature_column",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:tag_constants",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index 469cea4fd2..886e1846e2 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -44,7 +44,7 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
"""An Estimator to fit and evaluate a time series model."""
def __init__(self, model, state_manager=None, optimizer=None, model_dir=None,
- config=None):
+ config=None, head_type=ts_head_lib.TimeSeriesRegressionHead):
"""Initialize the Estimator.
Args:
@@ -55,6 +55,8 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
from tf.train.Optimizer. Defaults to Adam with step size 0.02.
model_dir: See `Estimator`.
config: See `Estimator`.
+ head_type: The kind of head to use for the model (inheriting from
+ `TimeSeriesRegressionHead`).
"""
input_statistics_generator = math_utils.InputStatisticsFromMiniBatch(
dtype=model.dtype, num_features=model.num_features)
@@ -63,8 +65,8 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
if optimizer is None:
optimizer = train.AdamOptimizer(0.02)
self._model = model
- ts_regression_head = ts_head_lib.time_series_regression_head(
- model, state_manager, optimizer,
+ ts_regression_head = head_type(
+ model=model, state_manager=state_manager, optimizer=optimizer,
input_statistics_generator=input_statistics_generator)
model_fn = ts_regression_head.create_estimator_spec
super(TimeSeriesRegressor, self).__init__(
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
index 71085f9de8..a28a5872b8 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -39,25 +39,6 @@ from tensorflow.python.util import nest
from tensorflow.python.summary import summary
-def time_series_regression_head(model,
- state_manager,
- optimizer,
- input_statistics_generator=None):
- """Creates a `_Head` for time series regression.
-
- Args:
- model: A model for time series regression.
- state_manager: A state manager.
- optimizer: An optimizer.
- input_statistics_generator: A input statistics generator.
-
- Returns:
- An instance of `_Head` for time series regression.
- """
- return _TimeSeriesRegressionHead(model, state_manager, optimizer,
- input_statistics_generator)
-
-
class _NoStatePredictOutput(export_lib.PredictOutput):
def as_signature_def(self, receiver_tensors):
@@ -68,8 +49,8 @@ class _NoStatePredictOutput(export_lib.PredictOutput):
receiver_tensors=no_state_receiver_tensors)
-class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-access
- """See `time_series_regression_head`."""
+class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-access
+ """Determines input and output signatures for a time series model."""
def __init__(self,
model,
@@ -77,6 +58,15 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc
optimizer,
input_statistics_generator=None,
name=None):
+ """Creates a `_Head` for time series regression.
+
+ Args:
+ model: A model for time series regression.
+ state_manager: A state manager.
+ optimizer: An optimizer.
+ input_statistics_generator: A input statistics generator.
+ name: An optional name for the model.
+ """
self.model = model
self.state_manager = state_manager
self.optimizer = optimizer
@@ -265,6 +255,58 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc
return self._serving_ops(features)
+class OneShotPredictionHead(TimeSeriesRegressionHead):
+ """A time series head which exports a single stateless serving signature.
+
+ The serving default signature exported by this head expects `times`, `values`,
+ and any exogenous features, but no state. `values` has shape `[batch_size,
+ filter_length, num_features]` and `times` has shape `[batch_size,
+ total_length]`, where `total_length > filter_length`. Any exogenous features
+ must have their shapes prefixed by the shape of the `times` feature.
+
+ When serving, first performs filtering on the series up to `filter_length`
+ starting from the default start state for the model, then computes predictions
+ on the remainder of the series, returning them.
+
+ Model state is neither accepted nor returned, so filtering must be performed
+ each time predictions are requested when using this head.
+ """
+
+ def _serving_ops(self, features):
+ """Add ops for serving to the graph."""
+ with variable_scope.variable_scope("model", use_resource=True):
+ filtering_features = {}
+ prediction_features = {}
+ values_length = array_ops.shape(
+ features[feature_keys.FilteringFeatures.VALUES])[1]
+ for key, value in features.items():
+ if key == feature_keys.State.STATE_TUPLE:
+ # Ignore state input. The model's default start state is replicated
+ # across the batch.
+ continue
+ if key == feature_keys.FilteringFeatures.VALUES:
+ filtering_features[key] = value
+ else:
+ filtering_features[key] = value[:, :values_length]
+ prediction_features[key] = value[:, values_length:]
+ cold_filtering_outputs = self.model.define_loss(
+ features=filtering_features, mode=estimator_lib.ModeKeys.EVAL)
+ prediction_features[feature_keys.State.STATE_TUPLE] = (
+ cold_filtering_outputs.end_state)
+ with variable_scope.variable_scope("model", reuse=True):
+ prediction_outputs = self.model.predict(
+ features=prediction_features)
+ return estimator_lib.EstimatorSpec(
+ mode=estimator_lib.ModeKeys.PREDICT,
+ export_outputs={
+ feature_keys.SavedModelLabels.PREDICT:
+ _NoStatePredictOutput(prediction_outputs),
+ },
+ # Likely unused, but it is necessary to return `predictions` to satisfy
+ # the Estimator's error checking.
+ predictions={})
+
+
def _check_feature_shapes_compatible_with(features,
compatible_with_name,
compatible_with_value,
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index 3415061cfd..c606db76a6 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -18,12 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy
+import six
+
+from tensorflow.contrib.timeseries.examples import lstm as lstm_example
+from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib
+from tensorflow.contrib.timeseries.python.timeseries import input_pipeline
from tensorflow.contrib.timeseries.python.timeseries import model
from tensorflow.contrib.timeseries.python.timeseries import state_management
+from tensorflow.python.client import session as session_lib
from tensorflow.python.estimator import estimator_lib
+from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -31,6 +39,9 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.training import adam
from tensorflow.python.training import coordinator as coordinator_lib
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import training as train
@@ -90,7 +101,7 @@ class EvaluationMetricsTests(test.TestCase):
.count_up_to(10),
dtype=dtypes.float32), (1, 1, 1))
}
- model_fn = ts_head_lib.time_series_regression_head(
+ model_fn = ts_head_lib.TimeSeriesRegressionHead(
model=_TickerModel(),
state_manager=state_management.PassthroughStateManager(),
optimizer=train.GradientDescentOptimizer(0.001)).create_estimator_spec
@@ -127,7 +138,7 @@ class _StubModel(object):
def _stub_model_fn():
- return ts_head_lib.time_series_regression_head(
+ return ts_head_lib.TimeSeriesRegressionHead(
model=_StubModel(),
state_manager=state_management.PassthroughStateManager(),
optimizer=train.AdamOptimizer(0.001)).create_estimator_spec
@@ -263,5 +274,76 @@ class PredictFeatureCheckingTests(test.TestCase):
mode=estimator_lib.ModeKeys.PREDICT)
+class OneShotTests(test.TestCase):
+
+ def test_one_shot_prediction_head_export(self):
+ model_dir = self.get_temp_dir()
+ categorical_column = feature_column.categorical_column_with_hash_bucket(
+ key="categorical_exogenous_feature", hash_bucket_size=16)
+ exogenous_feature_columns = [
+ feature_column.numeric_column(
+ "2d_exogenous_feature", shape=(2,)),
+ feature_column.embedding_column(
+ categorical_column=categorical_column, dimension=10)]
+ estimator = ts_estimators.TimeSeriesRegressor(
+ model=lstm_example._LSTMModel(
+ num_features=5, num_units=128,
+ exogenous_feature_columns=exogenous_feature_columns),
+ optimizer=adam.AdamOptimizer(0.001),
+ config=estimator_lib.RunConfig(tf_random_seed=4),
+ state_manager=state_management.ChainingStateManager(),
+ head_type=ts_head_lib.OneShotPredictionHead,
+ model_dir=model_dir)
+ train_features = {
+ feature_keys.TrainEvalFeatures.TIMES: numpy.arange(
+ 20, dtype=numpy.int64),
+ feature_keys.TrainEvalFeatures.VALUES: numpy.tile(numpy.arange(
+ 20, dtype=numpy.float32)[:, None], [1, 5]),
+ "2d_exogenous_feature": numpy.ones([20, 2]),
+ "categorical_exogenous_feature": numpy.array(
+ ["strkey"] * 20)[:, None]
+ }
+ train_input_fn = input_pipeline.RandomWindowInputFn(
+ input_pipeline.NumpyReader(train_features), shuffle_seed=2,
+ num_threads=1, batch_size=16, window_size=16)
+ estimator.train(input_fn=train_input_fn, steps=5)
+ input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
+ export_location = estimator.export_savedmodel(self.get_temp_dir(),
+ input_receiver_fn)
+ graph = ops.Graph()
+ with graph.as_default():
+ with session_lib.Session() as session:
+ signatures = loader.load(
+ session, [tag_constants.SERVING], export_location)
+ self.assertEqual([feature_keys.SavedModelLabels.PREDICT],
+ list(signatures.signature_def.keys()))
+ predict_signature = signatures.signature_def[
+ feature_keys.SavedModelLabels.PREDICT]
+ six.assertCountEqual(
+ self,
+ [feature_keys.FilteringFeatures.TIMES,
+ feature_keys.FilteringFeatures.VALUES,
+ "2d_exogenous_feature",
+ "categorical_exogenous_feature"],
+ predict_signature.inputs.keys())
+ features = {
+ feature_keys.TrainEvalFeatures.TIMES: numpy.tile(
+ numpy.arange(35, dtype=numpy.int64)[None, :], [2, 1]),
+ feature_keys.TrainEvalFeatures.VALUES: numpy.tile(numpy.arange(
+ 20, dtype=numpy.float32)[None, :, None], [2, 1, 5]),
+ "2d_exogenous_feature": numpy.ones([2, 35, 2]),
+ "categorical_exogenous_feature": numpy.tile(numpy.array(
+ ["strkey"] * 35)[None, :, None], [2, 1, 1])
+ }
+ feeds = {
+ graph.as_graph_element(input_value.name): features[input_key]
+ for input_key, input_value in predict_signature.inputs.items()}
+ fetches = {output_key: graph.as_graph_element(output_value.name)
+ for output_key, output_value
+ in predict_signature.outputs.items()}
+ output = session.run(fetches, feed_dict=feeds)
+ self.assertAllEqual((2, 15, 5), output["mean"].shape)
+
+
if __name__ == "__main__":
test.main()