diff options
author | 2018-03-27 15:55:04 -0700 | |
---|---|---|
committer | 2018-03-27 15:57:43 -0700 | |
commit | a16761483ec55095158b1b11118d93ea00a538f4 (patch) | |
tree | 1d248134f08e8c8cc9c3602c0db1a3faa7fdb6b5 /tensorflow/contrib/timeseries | |
parent | b4742b76c386409c96c60172e6ca1c1534e2b4af (diff) |
TFTS: Fix a bug in the SavedModel cold-start export
It now correctly broadcasts start state across whatever batch dimension it is
passed rather than sqishing it down to a batch dimension of 1.
PiperOrigin-RevId: 190688855
Diffstat (limited to 'tensorflow/contrib/timeseries')
-rw-r--r-- | tensorflow/contrib/timeseries/python/timeseries/estimators_test.py | 21 | ||||
-rw-r--r-- | tensorflow/contrib/timeseries/python/timeseries/head.py | 6 |
2 files changed, 25 insertions, 2 deletions
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py index f4304f2560..51d0c0ca3f 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py @@ -126,6 +126,27 @@ class TimeSeriesRegressorTest(test.TestCase): signatures=signatures, session=sess) + # Test cold starting + batch_numpy_times = numpy.tile( + numpy.arange(30, dtype=numpy.int64)[None, :], (10, 1)) + batch_numpy_values = numpy.ones([10, 30, 1]) + state = saved_model_utils.cold_start_filter( + signatures=signatures, + session=sess, + features={ + feature_keys.FilteringFeatures.TIMES: batch_numpy_times, + feature_keys.FilteringFeatures.VALUES: batch_numpy_values + } + ) + predict_times = numpy.tile( + numpy.arange(30, 45, dtype=numpy.int64)[None, :], (10, 1)) + predictions = saved_model_utils.predict_continuation( + continue_from=state, + times=predict_times, + signatures=signatures, + session=sess) + self.assertAllEqual([10, 15, 1], predictions["mean"].shape) + def test_fit_restore_fit_ar_regressor(self): def _estimator_fn(model_dir): return estimators.ARRegressor( diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index 3d7e615290..4cf6bbcfd4 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -154,8 +154,10 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc no_state_features = { k: v for k, v in features.items() if not k.startswith(feature_keys.State.STATE_PREFIX)} - cold_filtering_outputs = self.create_loss( - no_state_features, estimator_lib.ModeKeys.EVAL) + # Ignore any state management when cold-starting. The model's default + # start state is replicated across the batch. + cold_filtering_outputs = self.model.define_loss( + features=no_state_features, mode=estimator_lib.ModeKeys.EVAL) return estimator_lib.EstimatorSpec( mode=estimator_lib.ModeKeys.PREDICT, export_outputs={ |