From 847b38406a28546991b62193278ee87910cd3d74 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Tue, 11 Sep 2018 09:31:42 -0700 Subject: TFTS: Fix an input statistics race condition The fix is straightforward enough, although the triggering circumstances are still a bit mysterious. The unit test did fail with ubsan prior to this CL, so I'm going to leave it at that for now. PiperOrigin-RevId: 212465732 --- .../contrib/timeseries/python/timeseries/estimators_test.py | 9 +++++++++ tensorflow/contrib/timeseries/python/timeseries/math_utils.py | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) (limited to 'tensorflow/contrib/timeseries') diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py index 461fe22210..83260fc59a 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py @@ -216,6 +216,15 @@ class TimeSeriesRegressorTest(test.TestCase): exogenous_feature_columns=exogenous_feature_columns) self._fit_restore_fit_test_template(_estimator_fn, dtype=dtype) + def test_structural_ensemble_numpy_input(self): + numpy_data = {"times": numpy.arange(50), + "values": numpy.random.normal(size=[50])} + estimators.StructuralEnsembleRegressor( + num_features=1, periodicities=[], model_dir=self.get_temp_dir(), + config=_SeedRunConfig()).train( + input_pipeline.WholeDatasetInputFn( + input_pipeline.NumpyReader(numpy_data)), + steps=1) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index 9b593fecbb..03da2b82e5 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -896,8 +896,8 @@ class InputStatisticsFromMiniBatch(object): statistics.total_observation_count, math_ops.cast( gen_math_ops.round( - math_ops.cast(auxiliary_variables.max_time_seen - - statistics.start_time + 1, self._dtype) / + math_ops.cast(max_time_seen_assign - + start_time_update + 1, self._dtype) / inter_observation_duration_estimate), dtypes.int64)) per_chunk_stat_updates = control_flow_ops.group( overall_feature_mean_update, overall_feature_var_update, -- cgit v1.2.3