aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/timeseries
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-08-01 14:05:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-01 14:09:01 -0700
commit392a6772f9895b52dd9aa69507c165a71a75547e (patch)
treeb0f8f8833bc4308db0ac5f6fc85af615baf23c03 /tensorflow/contrib/timeseries
parent2087cc7e33038d54181c97470d9d21a86024857c (diff)
TFTS: Add a parsing serving_input_receiver_fn for tf.Example protos
Works with the one-shot head (no model state in the tf.Example proto). PiperOrigin-RevId: 206988925
Diffstat (limited to 'tensorflow/contrib/timeseries')
-rw-r--r--tensorflow/contrib/timeseries/__init__.py3
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/BUILD1
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/__init__.py1
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators.py168
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head.py81
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head_test.py60
6 files changed, 254 insertions, 60 deletions
diff --git a/tensorflow/contrib/timeseries/__init__.py b/tensorflow/contrib/timeseries/__init__.py
index 11db56b1b7..654a4db098 100644
--- a/tensorflow/contrib/timeseries/__init__.py
+++ b/tensorflow/contrib/timeseries/__init__.py
@@ -27,6 +27,9 @@
@@TrainEvalFeatures
@@FilteringResults
+
+@@TimeSeriesRegressor
+@@OneShotPredictionHead
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index 7020989d68..0e96c1fbd4 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -161,6 +161,7 @@ py_test(
srcs = [
"head_test.py",
],
+ shard_count = 4,
srcs_version = "PY2AND3",
tags = ["no_pip_gpu"], # b/63391119
deps = [
diff --git a/tensorflow/contrib/timeseries/python/timeseries/__init__.py b/tensorflow/contrib/timeseries/python/timeseries/__init__.py
index c683dad71d..8462138339 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/__init__.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/__init__.py
@@ -24,5 +24,6 @@ from tensorflow.contrib.timeseries.python.timeseries import saved_model_utils
from tensorflow.contrib.timeseries.python.timeseries.ar_model import *
from tensorflow.contrib.timeseries.python.timeseries.estimators import *
from tensorflow.contrib.timeseries.python.timeseries.feature_keys import *
+from tensorflow.contrib.timeseries.python.timeseries.head import *
from tensorflow.contrib.timeseries.python.timeseries.input_pipeline import *
# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index 769183f40a..0ddc4b4144 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -37,6 +37,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.training import training as train
from tensorflow.python.util import nest
@@ -79,12 +80,137 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
model_dir=model_dir,
config=config)
- # TODO(allenl): A parsing input receiver function, which takes a serialized
- # tf.Example containing all features (times, values, any exogenous features)
- # and serialized model state (possibly also as a tf.Example).
- def build_raw_serving_input_receiver_fn(self,
- default_batch_size=None,
- default_series_length=None):
+ def _model_start_state_placeholders(
+ self, batch_size_tensor, static_batch_size=None):
+ """Creates placeholders with zeroed start state for the current model."""
+ gathered_state = {}
+ # Models may not know the shape of their state without creating some
+ # variables/ops. Avoid polluting the default graph by making a new one. We
+ # use only static metadata from the returned Tensors.
+ with ops.Graph().as_default():
+ self._model.initialize_graph()
+ # Evaluate the initial state as same-dtype "zero" values. These zero
+ # constants aren't used, but are necessary for feeding to
+ # placeholder_with_default for the "cold start" case where state is not
+ # fed to the model.
+ def _zeros_like_constant(tensor):
+ return tensor_util.constant_value(array_ops.zeros_like(tensor))
+ start_state = nest.map_structure(
+ _zeros_like_constant, self._model.get_start_state())
+ for prefixed_state_name, state in ts_head_lib.state_to_dictionary(
+ start_state).items():
+ state_shape_with_batch = tensor_shape.TensorShape(
+ (static_batch_size,)).concatenate(state.shape)
+ default_state_broadcast = array_ops.tile(
+ state[None, ...],
+ multiples=array_ops.concat(
+ [batch_size_tensor[None],
+ array_ops.ones(len(state.shape), dtype=dtypes.int32)],
+ axis=0))
+ gathered_state[prefixed_state_name] = array_ops.placeholder_with_default(
+ input=default_state_broadcast,
+ name=prefixed_state_name,
+ shape=state_shape_with_batch)
+ return gathered_state
+
+ def build_one_shot_parsing_serving_input_receiver_fn(
+ self, filtering_length, prediction_length, default_batch_size=None,
+ values_input_dtype=None, truncate_values=False):
+ """Build an input_receiver_fn for export_savedmodel accepting tf.Examples.
+
+ Only compatible with `OneShotPredictionHead` (see `head`).
+
+ Args:
+ filtering_length: The number of time steps used as input to the model, for
+ which values are provided. If more than `filtering_length` values are
+ provided (via `truncate_values`), only the first `filtering_length`
+ values are used.
+ prediction_length: The number of time steps requested as predictions from
+ the model. Times and all exogenous features must be provided for these
+ steps.
+ default_batch_size: If specified, must be a scalar integer. Sets the batch
+ size in the static shape information of all feature Tensors, which means
+ only this batch size will be accepted by the exported model. If None
+ (default), static shape information for batch sizes is omitted.
+ values_input_dtype: An optional dtype specification for values in the
+ tf.Example protos (either float32 or int64, since these are the numeric
+ types supported by tf.Example). After parsing, values are cast to the
+ model's dtype (float32 or float64).
+ truncate_values: If True, expects `filtering_length + prediction_length`
+ values to be provided, but only uses the first `filtering_length`. If
+ False (default), exactly `filtering_length` values must be provided.
+
+ Returns:
+ An input_receiver_fn which may be passed to the Estimator's
+ export_savedmodel.
+
+ Expects features contained in a vector of serialized tf.Examples with
+ shape [batch size] (dtype `tf.string`), each tf.Example containing
+ features with the following shapes:
+ times: [filtering_length + prediction_length] integer
+ values: [filtering_length, num features] floating point. If
+ `truncate_values` is True, expects `filtering_length +
+ prediction_length` values but only uses the first `filtering_length`.
+ all exogenous features: [filtering_length + prediction_length, ...]
+ (various dtypes)
+ """
+ if values_input_dtype is None:
+ values_input_dtype = dtypes.float32
+ if truncate_values:
+ values_proto_length = filtering_length + prediction_length
+ else:
+ values_proto_length = filtering_length
+
+ def _serving_input_receiver_fn():
+ """A receiver function to be passed to export_savedmodel."""
+ times_column = feature_column.numeric_column(
+ key=feature_keys.TrainEvalFeatures.TIMES, dtype=dtypes.int64)
+ values_column = feature_column.numeric_column(
+ key=feature_keys.TrainEvalFeatures.VALUES, dtype=values_input_dtype,
+ shape=(self._model.num_features,))
+ parsed_features_no_sequence = (
+ feature_column.make_parse_example_spec(
+ list(self._model.exogenous_feature_columns)
+ + [times_column, values_column]))
+ parsed_features = {}
+ for key, feature_spec in parsed_features_no_sequence.items():
+ if isinstance(feature_spec, parsing_ops.FixedLenFeature):
+ if key == feature_keys.TrainEvalFeatures.VALUES:
+ parsed_features[key] = feature_spec._replace(
+ shape=((values_proto_length,)
+ + feature_spec.shape))
+ else:
+ parsed_features[key] = feature_spec._replace(
+ shape=((filtering_length + prediction_length,)
+ + feature_spec.shape))
+ elif feature_spec.dtype == dtypes.string:
+ parsed_features[key] = parsing_ops.FixedLenFeature(
+ shape=(filtering_length + prediction_length,),
+ dtype=dtypes.string)
+ else: # VarLenFeature
+ raise ValueError("VarLenFeatures not supported, got %s for key %s"
+ % (feature_spec, key))
+ tfexamples = array_ops.placeholder(
+ shape=[default_batch_size], dtype=dtypes.string, name="input")
+ features = parsing_ops.parse_example(
+ serialized=tfexamples,
+ features=parsed_features)
+ features[feature_keys.TrainEvalFeatures.TIMES] = array_ops.squeeze(
+ features[feature_keys.TrainEvalFeatures.TIMES], axis=-1)
+ features[feature_keys.TrainEvalFeatures.VALUES] = math_ops.cast(
+ features[feature_keys.TrainEvalFeatures.VALUES],
+ dtype=self._model.dtype)[:, :filtering_length]
+ features.update(
+ self._model_start_state_placeholders(
+ batch_size_tensor=array_ops.shape(
+ features[feature_keys.TrainEvalFeatures.TIMES])[0],
+ static_batch_size=default_batch_size))
+ return export_lib.ServingInputReceiver(
+ features, {"examples": tfexamples})
+ return _serving_input_receiver_fn
+
+ def build_raw_serving_input_receiver_fn(
+ self, default_batch_size=None, default_series_length=None):
"""Build an input_receiver_fn for export_savedmodel which accepts arrays.
Automatically creates placeholders for exogenous `FeatureColumn`s passed to
@@ -149,34 +275,10 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
+ batch_only_feature_shape[1:])
placeholders[feature_key] = array_ops.placeholder(
dtype=value_dtype, name=feature_key, shape=feature_shape)
- # Models may not know the shape of their state without creating some
- # variables/ops. Avoid polluting the default graph by making a new one. We
- # use only static metadata from the returned Tensors.
- with ops.Graph().as_default():
- self._model.initialize_graph()
- # Evaluate the initial state as same-dtype "zero" values. These zero
- # constants aren't used, but are necessary for feeding to
- # placeholder_with_default for the "cold start" case where state is not
- # fed to the model.
- def _zeros_like_constant(tensor):
- return tensor_util.constant_value(array_ops.zeros_like(tensor))
- start_state = nest.map_structure(
- _zeros_like_constant, self._model.get_start_state())
batch_size_tensor = array_ops.shape(time_placeholder)[0]
- for prefixed_state_name, state in ts_head_lib.state_to_dictionary(
- start_state).items():
- state_shape_with_batch = tensor_shape.TensorShape(
- (default_batch_size,)).concatenate(state.shape)
- default_state_broadcast = array_ops.tile(
- state[None, ...],
- multiples=array_ops.concat(
- [batch_size_tensor[None],
- array_ops.ones(len(state.shape), dtype=dtypes.int32)],
- axis=0))
- placeholders[prefixed_state_name] = array_ops.placeholder_with_default(
- input=default_state_broadcast,
- name=prefixed_state_name,
- shape=state_shape_with_batch)
+ placeholders.update(
+ self._model_start_state_placeholders(
+ batch_size_tensor, static_batch_size=default_batch_size))
return export_lib.ServingInputReceiver(placeholders, placeholders)
return _serving_input_receiver_fn
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
index 8686a803e5..d2484d0ef5 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -26,6 +26,7 @@ from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.export import export_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -180,7 +181,7 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
return math_ops.cast(value, self.model.dtype)
if name == feature_keys.PredictionFeatures.STATE_TUPLE:
return value # Correct dtypes are model-dependent
- return ops.convert_to_tensor(value)
+ return sparse_tensor.convert_to_tensor_or_sparse_tensor(value)
def _gather_state(self, features):
"""Returns `features` with state packed, indicates if packing was done."""
@@ -202,6 +203,29 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
flat_sequence=[tensor for _, _, tensor in numbered_state])
return features, True
+ def _check_predict_features(self, features):
+ """Raises errors if features are not suitable for prediction."""
+ if feature_keys.PredictionFeatures.TIMES not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.TIMES))
+ if feature_keys.PredictionFeatures.STATE_TUPLE not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.STATE_TUPLE))
+ times_feature = features[feature_keys.PredictionFeatures.TIMES]
+ if not times_feature.get_shape().is_compatible_with([None, None]):
+ raise ValueError(
+ ("Expected shape (batch dimension, window size) for feature '{}' "
+ "(got shape {})").format(feature_keys.PredictionFeatures.TIMES,
+ times_feature.get_shape()))
+ _check_feature_shapes_compatible_with(
+ features=features,
+ compatible_with_name=feature_keys.PredictionFeatures.TIMES,
+ compatible_with_value=times_feature,
+ ignore=set([
+ # Model-dependent shapes
+ feature_keys.PredictionFeatures.STATE_TUPLE
+ ]))
+
def create_estimator_spec(self, features, mode, labels=None):
"""Performs basic error checking and returns an EstimatorSpec."""
with ops.name_scope(self._name, "head"):
@@ -230,7 +254,7 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
mode == estimator_lib.ModeKeys.EVAL):
_check_train_eval_features(features, self.model)
elif mode == estimator_lib.ModeKeys.PREDICT:
- _check_predict_features(features)
+ self._check_predict_features(features)
else:
raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode))
@@ -267,6 +291,36 @@ class OneShotPredictionHead(TimeSeriesRegressionHead):
each time predictions are requested when using this head.
"""
+ def _check_predict_features(self, features):
+ """Raises errors if features are not suitable for one-shot prediction."""
+ if feature_keys.PredictionFeatures.TIMES not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.TIMES))
+ if feature_keys.TrainEvalFeatures.VALUES not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.TrainEvalFeatures.VALUES))
+ if feature_keys.PredictionFeatures.STATE_TUPLE not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.STATE_TUPLE))
+ times_feature = features[feature_keys.PredictionFeatures.TIMES]
+ if not times_feature.get_shape().is_compatible_with([None, None]):
+ raise ValueError(
+ ("Expected shape (batch dimension, window size) for feature '{}' "
+ "(got shape {})").format(feature_keys.PredictionFeatures.TIMES,
+ times_feature.get_shape()))
+ _check_feature_shapes_compatible_with(
+ features=features,
+ compatible_with_name=feature_keys.PredictionFeatures.TIMES,
+ compatible_with_value=times_feature,
+ ignore=set([
+ # Model-dependent shapes
+ feature_keys.PredictionFeatures.STATE_TUPLE,
+ # One shot prediction head relies on values being shorter than
+ # times. Even though we're predicting eventually, we need values for
+ # the filtering phase.
+ feature_keys.TrainEvalFeatures.VALUES,
+ ]))
+
def _serving_ops(self, features):
"""Add ops for serving to the graph."""
with variable_scope.variable_scope("model", use_resource=True):
@@ -333,29 +387,6 @@ def _check_feature_shapes_compatible_with(features,
times_shape=compatible_with_value.get_shape()))
-def _check_predict_features(features):
- """Raises errors if features are not suitable for prediction."""
- if feature_keys.PredictionFeatures.TIMES not in features:
- raise ValueError("Expected a '{}' feature for prediction.".format(
- feature_keys.PredictionFeatures.TIMES))
- if feature_keys.PredictionFeatures.STATE_TUPLE not in features:
- raise ValueError("Expected a '{}' feature for prediction.".format(
- feature_keys.PredictionFeatures.STATE_TUPLE))
- times_feature = features[feature_keys.PredictionFeatures.TIMES]
- if not times_feature.get_shape().is_compatible_with([None, None]):
- raise ValueError(
- ("Expected shape (batch dimension, window size) for feature '{}' "
- "(got shape {})").format(feature_keys.PredictionFeatures.TIMES,
- times_feature.get_shape()))
- _check_feature_shapes_compatible_with(
- features=features,
- compatible_with_name=feature_keys.PredictionFeatures.TIMES,
- compatible_with_value=times_feature,
- ignore=set([
- feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes
- ]))
-
-
def _check_train_eval_features(features, model):
"""Raise errors if features are not suitable for training/evaluation."""
if feature_keys.TrainEvalFeatures.TIMES not in features:
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index 78c2cec21c..857e7c5635 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
import os
from absl.testing import parameterized
@@ -26,12 +27,14 @@ import six
from tensorflow.contrib.estimator.python.estimator import extenders
from tensorflow.contrib.timeseries.examples import lstm as lstm_example
+from tensorflow.contrib.timeseries.python.timeseries import ar_model
from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib
from tensorflow.contrib.timeseries.python.timeseries import input_pipeline
from tensorflow.contrib.timeseries.python.timeseries import model
from tensorflow.contrib.timeseries.python.timeseries import state_management
+from tensorflow.core.example import example_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.estimator import estimator_lib
@@ -343,15 +346,33 @@ def _structural_ensemble_regressor(
model_dir=model_dir)
+def _ar_lstm_regressor(
+ model_dir, head_type, exogenous_feature_columns):
+ return ts_estimators.TimeSeriesRegressor(
+ model=ar_model.ARModel(
+ periodicities=10, input_window_size=10, output_window_size=6,
+ num_features=5,
+ exogenous_feature_columns=exogenous_feature_columns,
+ prediction_model_factory=functools.partial(
+ ar_model.LSTMPredictionModel,
+ num_units=10)),
+ head_type=head_type,
+ model_dir=model_dir)
+
+
class OneShotTests(parameterized.TestCase):
@parameterized.named_parameters(
+ {"testcase_name": "ar_lstm_regressor",
+ "estimator_factory": _ar_lstm_regressor},
{"testcase_name": "custom_time_series_regressor",
"estimator_factory": _custom_time_series_regressor},
{"testcase_name": "structural_ensemble_regressor",
"estimator_factory": _structural_ensemble_regressor})
def test_one_shot_prediction_head_export(self, estimator_factory):
- model_dir = os.path.join(test.get_temp_dir(), str(ops.uid()))
+ def _new_temp_dir():
+ return os.path.join(test.get_temp_dir(), str(ops.uid()))
+ model_dir = _new_temp_dir()
categorical_column = feature_column.categorical_column_with_hash_bucket(
key="categorical_exogenous_feature", hash_bucket_size=16)
exogenous_feature_columns = [
@@ -377,7 +398,7 @@ class OneShotTests(parameterized.TestCase):
num_threads=1, batch_size=16, window_size=16)
estimator.train(input_fn=train_input_fn, steps=5)
input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
- export_location = estimator.export_savedmodel(test.get_temp_dir(),
+ export_location = estimator.export_savedmodel(_new_temp_dir(),
input_receiver_fn)
graph = ops.Graph()
with graph.as_default():
@@ -412,6 +433,41 @@ class OneShotTests(parameterized.TestCase):
in predict_signature.outputs.items()}
output = session.run(fetches, feed_dict=feeds)
self.assertEqual((2, 15, 5), output["mean"].shape)
+ # Build a parsing input function, then make a tf.Example for it to parse.
+ export_location = estimator.export_savedmodel(
+ _new_temp_dir(),
+ estimator.build_one_shot_parsing_serving_input_receiver_fn(
+ filtering_length=20, prediction_length=15))
+ graph = ops.Graph()
+ with graph.as_default():
+ with session_lib.Session() as session:
+ example = example_pb2.Example()
+ times = example.features.feature[feature_keys.TrainEvalFeatures.TIMES]
+ values = example.features.feature[feature_keys.TrainEvalFeatures.VALUES]
+ times.int64_list.value.extend(range(35))
+ for i in range(20):
+ values.float_list.value.extend(
+ [float(i) * 2. + feature_number
+ for feature_number in range(5)])
+ real_feature = example.features.feature["2d_exogenous_feature"]
+ categortical_feature = example.features.feature[
+ "categorical_exogenous_feature"]
+ for i in range(35):
+ real_feature.float_list.value.extend([1, 1])
+ categortical_feature.bytes_list.value.append(b"strkey")
+ # Serialize the tf.Example for feeding to the Session
+ examples = [example.SerializeToString()] * 2
+ signatures = loader.load(
+ session, [tag_constants.SERVING], export_location)
+ predict_signature = signatures.signature_def[
+ feature_keys.SavedModelLabels.PREDICT]
+ ((_, input_value),) = predict_signature.inputs.items()
+ feeds = {graph.as_graph_element(input_value.name): examples}
+ fetches = {output_key: graph.as_graph_element(output_value.name)
+ for output_key, output_value
+ in predict_signature.outputs.items()}
+ output = session.run(fetches, feed_dict=feeds)
+ self.assertEqual((2, 15, 5), output["mean"].shape)
if __name__ == "__main__":