aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/timeseries
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-07-03 14:02:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-03 14:07:17 -0700
commitb6ebb965937a62e15299cec7c91896fd011ce416 (patch)
tree29da28569f1e7bdf5e3ec1c5fc9751f7b51a6b6e /tensorflow/contrib/timeseries
parentd04622b53443d61227880c6a930ab981a13bc83a (diff)
TFTS: Add a head config option to StructuralEnsembleRegressor
Adds a unit test for OneShotHead, fiddles with the train_op to deal with a dtype error the new unit tests uncovered. PiperOrigin-RevId: 203179477
Diffstat (limited to 'tensorflow/contrib/timeseries')
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/BUILD1
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators.py13
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head.py14
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head_test.py56
4 files changed, 55 insertions, 29 deletions
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index e4963596d3..ec9a7861e7 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -184,6 +184,7 @@ py_test(
"//tensorflow/python/saved_model:loader",
"//tensorflow/python/saved_model:tag_constants",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index 4ec8d26116..769183f40a 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -288,7 +288,7 @@ class StateSpaceRegressor(TimeSeriesRegressor):
"""An Estimator for general state space models."""
def __init__(self, model, state_manager=None, optimizer=None, model_dir=None,
- config=None):
+ config=None, head_type=ts_head_lib.TimeSeriesRegressionHead):
"""See TimeSeriesRegressor. Uses the ChainingStateManager by default."""
if not isinstance(model, state_space_model.StateSpaceModel):
raise ValueError(
@@ -301,7 +301,8 @@ class StateSpaceRegressor(TimeSeriesRegressor):
state_manager=state_manager,
optimizer=optimizer,
model_dir=model_dir,
- config=config)
+ config=config,
+ head_type=head_type)
class StructuralEnsembleRegressor(StateSpaceRegressor):
@@ -344,7 +345,8 @@ class StructuralEnsembleRegressor(StateSpaceRegressor):
anomaly_prior_probability=None,
optimizer=None,
model_dir=None,
- config=None):
+ config=None,
+ head_type=ts_head_lib.TimeSeriesRegressionHead):
"""Initialize the Estimator.
Args:
@@ -401,6 +403,8 @@ class StructuralEnsembleRegressor(StateSpaceRegressor):
from tf.train.Optimizer. Defaults to Adam with step size 0.02.
model_dir: See `Estimator`.
config: See `Estimator`.
+ head_type: The kind of head to use for the model (inheriting from
+ `TimeSeriesRegressionHead`).
"""
if anomaly_prior_probability is not None:
filtering_postprocessor = StateInterpolatingAnomalyDetector(
@@ -424,4 +428,5 @@ class StructuralEnsembleRegressor(StateSpaceRegressor):
model=model,
optimizer=optimizer,
model_dir=model_dir,
- config=config)
+ config=config,
+ head_type=head_type)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
index f236329fdb..8686a803e5 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -19,11 +19,7 @@ from __future__ import print_function
import re
-from tensorflow.python.training import training_util
-from tensorflow.contrib.layers.python.layers import optimizers
-
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
-
from tensorflow.python.estimator import estimator_lib
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import metric_keys
@@ -35,8 +31,9 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.util import nest
from tensorflow.python.summary import summary
+from tensorflow.python.training import training_util
+from tensorflow.python.util import nest
class _NoStatePredictOutput(export_lib.PredictOutput):
@@ -102,12 +99,9 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
use_resource=True):
model_outputs = self.create_loss(features, mode)
- train_op = optimizers.optimize_loss(
+ train_op = self.optimizer.minimize(
model_outputs.loss,
- global_step=training_util.get_global_step(),
- optimizer=self.optimizer,
- # Learning rate is set in the Optimizer object
- learning_rate=None)
+ global_step=training_util.get_global_step())
return estimator_lib.EstimatorSpec(
loss=model_outputs.loss,
mode=mode,
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index ed8f29c321..78c2cec21c 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
+from absl.testing import parameterized
import numpy
import six
@@ -317,10 +320,38 @@ class PredictFeatureCheckingTests(test.TestCase):
mode=estimator_lib.ModeKeys.PREDICT)
-class OneShotTests(test.TestCase):
-
- def test_one_shot_prediction_head_export(self):
- model_dir = self.get_temp_dir()
+def _custom_time_series_regressor(
+ model_dir, head_type, exogenous_feature_columns):
+ return ts_estimators.TimeSeriesRegressor(
+ model=lstm_example._LSTMModel(
+ num_features=5, num_units=128,
+ exogenous_feature_columns=exogenous_feature_columns),
+ optimizer=adam.AdamOptimizer(0.001),
+ config=estimator_lib.RunConfig(tf_random_seed=4),
+ state_manager=state_management.ChainingStateManager(),
+ head_type=head_type,
+ model_dir=model_dir)
+
+
+def _structural_ensemble_regressor(
+ model_dir, head_type, exogenous_feature_columns):
+ return ts_estimators.StructuralEnsembleRegressor(
+ periodicities=None,
+ num_features=5,
+ exogenous_feature_columns=exogenous_feature_columns,
+ head_type=head_type,
+ model_dir=model_dir)
+
+
+class OneShotTests(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {"testcase_name": "custom_time_series_regressor",
+ "estimator_factory": _custom_time_series_regressor},
+ {"testcase_name": "structural_ensemble_regressor",
+ "estimator_factory": _structural_ensemble_regressor})
+ def test_one_shot_prediction_head_export(self, estimator_factory):
+ model_dir = os.path.join(test.get_temp_dir(), str(ops.uid()))
categorical_column = feature_column.categorical_column_with_hash_bucket(
key="categorical_exogenous_feature", hash_bucket_size=16)
exogenous_feature_columns = [
@@ -328,15 +359,10 @@ class OneShotTests(test.TestCase):
"2d_exogenous_feature", shape=(2,)),
feature_column.embedding_column(
categorical_column=categorical_column, dimension=10)]
- estimator = ts_estimators.TimeSeriesRegressor(
- model=lstm_example._LSTMModel(
- num_features=5, num_units=128,
- exogenous_feature_columns=exogenous_feature_columns),
- optimizer=adam.AdamOptimizer(0.001),
- config=estimator_lib.RunConfig(tf_random_seed=4),
- state_manager=state_management.ChainingStateManager(),
- head_type=ts_head_lib.OneShotPredictionHead,
- model_dir=model_dir)
+ estimator = estimator_factory(
+ model_dir=model_dir,
+ exogenous_feature_columns=exogenous_feature_columns,
+ head_type=ts_head_lib.OneShotPredictionHead)
train_features = {
feature_keys.TrainEvalFeatures.TIMES: numpy.arange(
20, dtype=numpy.int64),
@@ -351,7 +377,7 @@ class OneShotTests(test.TestCase):
num_threads=1, batch_size=16, window_size=16)
estimator.train(input_fn=train_input_fn, steps=5)
input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
- export_location = estimator.export_savedmodel(self.get_temp_dir(),
+ export_location = estimator.export_savedmodel(test.get_temp_dir(),
input_receiver_fn)
graph = ops.Graph()
with graph.as_default():
@@ -385,7 +411,7 @@ class OneShotTests(test.TestCase):
for output_key, output_value
in predict_signature.outputs.items()}
output = session.run(fetches, feed_dict=feeds)
- self.assertAllEqual((2, 15, 5), output["mean"].shape)
+ self.assertEqual((2, 15, 5), output["mean"].shape)
if __name__ == "__main__":