aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/timeseries
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-04-02 11:53:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-02 11:55:51 -0700
commit6c4095c7353666c4b75ce189e68860be1159b40a (patch)
tree676f1c857f12d5bde66d4815ad186ea8f76d6e8d /tensorflow/contrib/timeseries
parent5d81b72b9c1a7edd1a84c13b1dc753b310545e56 (diff)
TFTS: Clean up the cold start SignatureDef.
Removes state where it wasn't used. PiperOrigin-RevId: 191324834
Diffstat (limited to 'tensorflow/contrib/timeseries')
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators_test.py7
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head.py12
2 files changed, 18 insertions, 1 deletions
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index 51d0c0ca3f..9f161c1695 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import tempfile
import numpy
+import six
from tensorflow.contrib.timeseries.python.timeseries import ar_model
from tensorflow.contrib.timeseries.python.timeseries import estimators
@@ -127,6 +128,12 @@ class TimeSeriesRegressorTest(test.TestCase):
session=sess)
# Test cold starting
+ six.assertCountEqual(
+ self,
+ [feature_keys.FilteringFeatures.TIMES,
+ feature_keys.FilteringFeatures.VALUES],
+ signatures.signature_def[
+ feature_keys.SavedModelLabels.COLD_START_FILTER].inputs.keys())
batch_numpy_times = numpy.tile(
numpy.arange(30, dtype=numpy.int64)[None, :], (10, 1))
batch_numpy_values = numpy.ones([10, 30, 1])
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
index 4cf6bbcfd4..71085f9de8 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -58,6 +58,16 @@ def time_series_regression_head(model,
input_statistics_generator)
+class _NoStatePredictOutput(export_lib.PredictOutput):
+
+ def as_signature_def(self, receiver_tensors):
+ no_state_receiver_tensors = {
+ key: value for key, value in receiver_tensors.items()
+ if not key.startswith(feature_keys.State.STATE_PREFIX)}
+ return super(_NoStatePredictOutput, self).as_signature_def(
+ receiver_tensors=no_state_receiver_tensors)
+
+
class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-access
"""See `time_series_regression_head`."""
@@ -167,7 +177,7 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc
export_lib.PredictOutput(
state_to_dictionary(filtering_outputs.end_state)),
feature_keys.SavedModelLabels.COLD_START_FILTER:
- export_lib.PredictOutput(
+ _NoStatePredictOutput(
state_to_dictionary(cold_filtering_outputs.end_state))
},
# Likely unused, but it is necessary to return `predictions` to satisfy