aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/timeseries
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-09-11 09:31:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-11 09:34:55 -0700
commit847b38406a28546991b62193278ee87910cd3d74 (patch)
tree1669f6cb995b6d09b70a6d2bf7b9180c65a540e4 /tensorflow/contrib/timeseries
parentde5ddd51e32c4630e63c0cb3e960c69f9ac77662 (diff)
TFTS: Fix an input statistics race condition
The fix is straightforward enough, although the triggering circumstances are still a bit mysterious. The unit test did fail with ubsan prior to this CL, so I'm going to leave it at that for now. PiperOrigin-RevId: 212465732
Diffstat (limited to 'tensorflow/contrib/timeseries')
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators_test.py9
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/math_utils.py4
2 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index 461fe22210..83260fc59a 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -216,6 +216,15 @@ class TimeSeriesRegressorTest(test.TestCase):
exogenous_feature_columns=exogenous_feature_columns)
self._fit_restore_fit_test_template(_estimator_fn, dtype=dtype)
+ def test_structural_ensemble_numpy_input(self):
+ numpy_data = {"times": numpy.arange(50),
+ "values": numpy.random.normal(size=[50])}
+ estimators.StructuralEnsembleRegressor(
+ num_features=1, periodicities=[], model_dir=self.get_temp_dir(),
+ config=_SeedRunConfig()).train(
+ input_pipeline.WholeDatasetInputFn(
+ input_pipeline.NumpyReader(numpy_data)),
+ steps=1)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
index 9b593fecbb..03da2b82e5 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
@@ -896,8 +896,8 @@ class InputStatisticsFromMiniBatch(object):
statistics.total_observation_count,
math_ops.cast(
gen_math_ops.round(
- math_ops.cast(auxiliary_variables.max_time_seen -
- statistics.start_time + 1, self._dtype) /
+ math_ops.cast(max_time_seen_assign -
+ start_time_update + 1, self._dtype) /
inter_observation_duration_estimate), dtypes.int64))
per_chunk_stat_updates = control_flow_ops.group(
overall_feature_mean_update, overall_feature_var_update,