diff options
author | Allen Lavoie <allenl@google.com> | 2018-04-02 11:53:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-02 11:55:51 -0700 |
commit | 6c4095c7353666c4b75ce189e68860be1159b40a (patch) | |
tree | 676f1c857f12d5bde66d4815ad186ea8f76d6e8d /tensorflow/contrib/timeseries | |
parent | 5d81b72b9c1a7edd1a84c13b1dc753b310545e56 (diff) |
TFTS: Clean up the cold start SignatureDef.
Removes state where it wasn't used.
PiperOrigin-RevId: 191324834
Diffstat (limited to 'tensorflow/contrib/timeseries')
-rw-r--r-- | tensorflow/contrib/timeseries/python/timeseries/estimators_test.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/timeseries/python/timeseries/head.py | 12 |
2 files changed, 18 insertions, 1 deletions
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py index 51d0c0ca3f..9f161c1695 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import tempfile import numpy +import six from tensorflow.contrib.timeseries.python.timeseries import ar_model from tensorflow.contrib.timeseries.python.timeseries import estimators @@ -127,6 +128,12 @@ class TimeSeriesRegressorTest(test.TestCase): session=sess) # Test cold starting + six.assertCountEqual( + self, + [feature_keys.FilteringFeatures.TIMES, + feature_keys.FilteringFeatures.VALUES], + signatures.signature_def[ + feature_keys.SavedModelLabels.COLD_START_FILTER].inputs.keys()) batch_numpy_times = numpy.tile( numpy.arange(30, dtype=numpy.int64)[None, :], (10, 1)) batch_numpy_values = numpy.ones([10, 30, 1]) diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index 4cf6bbcfd4..71085f9de8 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -58,6 +58,16 @@ def time_series_regression_head(model, input_statistics_generator) +class _NoStatePredictOutput(export_lib.PredictOutput): + + def as_signature_def(self, receiver_tensors): + no_state_receiver_tensors = { + key: value for key, value in receiver_tensors.items() + if not key.startswith(feature_keys.State.STATE_PREFIX)} + return super(_NoStatePredictOutput, self).as_signature_def( + receiver_tensors=no_state_receiver_tensors) + + class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-access """See `time_series_regression_head`.""" @@ -167,7 +177,7 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc export_lib.PredictOutput( state_to_dictionary(filtering_outputs.end_state)), feature_keys.SavedModelLabels.COLD_START_FILTER: - export_lib.PredictOutput( + _NoStatePredictOutput( state_to_dictionary(cold_filtering_outputs.end_state)) }, # Likely unused, but it is necessary to return `predictions` to satisfy |