aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/timeseries
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-03-27 15:55:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-27 15:57:43 -0700
commita16761483ec55095158b1b11118d93ea00a538f4 (patch)
tree1d248134f08e8c8cc9c3602c0db1a3faa7fdb6b5 /tensorflow/contrib/timeseries
parentb4742b76c386409c96c60172e6ca1c1534e2b4af (diff)
TFTS: Fix a bug in the SavedModel cold-start export
It now correctly broadcasts start state across whatever batch dimension it is passed rather than sqishing it down to a batch dimension of 1. PiperOrigin-RevId: 190688855
Diffstat (limited to 'tensorflow/contrib/timeseries')
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators_test.py21
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head.py6
2 files changed, 25 insertions, 2 deletions
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index f4304f2560..51d0c0ca3f 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -126,6 +126,27 @@ class TimeSeriesRegressorTest(test.TestCase):
signatures=signatures,
session=sess)
+ # Test cold starting
+ batch_numpy_times = numpy.tile(
+ numpy.arange(30, dtype=numpy.int64)[None, :], (10, 1))
+ batch_numpy_values = numpy.ones([10, 30, 1])
+ state = saved_model_utils.cold_start_filter(
+ signatures=signatures,
+ session=sess,
+ features={
+ feature_keys.FilteringFeatures.TIMES: batch_numpy_times,
+ feature_keys.FilteringFeatures.VALUES: batch_numpy_values
+ }
+ )
+ predict_times = numpy.tile(
+ numpy.arange(30, 45, dtype=numpy.int64)[None, :], (10, 1))
+ predictions = saved_model_utils.predict_continuation(
+ continue_from=state,
+ times=predict_times,
+ signatures=signatures,
+ session=sess)
+ self.assertAllEqual([10, 15, 1], predictions["mean"].shape)
+
def test_fit_restore_fit_ar_regressor(self):
def _estimator_fn(model_dir):
return estimators.ARRegressor(
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
index 3d7e615290..4cf6bbcfd4 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -154,8 +154,10 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc
no_state_features = {
k: v for k, v in features.items()
if not k.startswith(feature_keys.State.STATE_PREFIX)}
- cold_filtering_outputs = self.create_loss(
- no_state_features, estimator_lib.ModeKeys.EVAL)
+ # Ignore any state management when cold-starting. The model's default
+ # start state is replicated across the batch.
+ cold_filtering_outputs = self.model.define_loss(
+ features=no_state_features, mode=estimator_lib.ModeKeys.EVAL)
return estimator_lib.EstimatorSpec(
mode=estimator_lib.ModeKeys.PREDICT,
export_outputs={