diff options
author | Allen Lavoie <allenl@google.com> | 2018-08-20 17:52:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-20 17:58:27 -0700 |
commit | 331e53b1d0c922fdf923e17a803e66e694bd5368 (patch) | |
tree | 637eb2e8fc7a32a6df761af48bac4d1b03f86589 | |
parent | d01bdee827aa0b3e6688dc2bfd48ad65f3891e7a (diff) |
Make the "predict" time series example more user friendly.
Doesn't require an input_filename flag, uses the default example input instead.
PiperOrigin-RevId: 209518060
-rw-r--r-- | tensorflow/contrib/timeseries/examples/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/timeseries/examples/predict.py | 16 |
2 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD index 355303acf6..71b0d48798 100644 --- a/tensorflow/contrib/timeseries/examples/BUILD +++ b/tensorflow/contrib/timeseries/examples/BUILD @@ -16,6 +16,7 @@ config_setting( py_binary( name = "predict", srcs = ["predict.py"], + data = ["data/period_trend.csv"], srcs_version = "PY2AND3", tags = ["no_pip"], deps = select({ @@ -31,7 +32,6 @@ py_test( name = "predict_test", timeout = "long", # Moderate but for asan srcs = ["predict_test.py"], - data = ["data/period_trend.csv"], srcs_version = "PY2AND3", tags = [ "no_windows", # TODO: needs investigation on Windows diff --git a/tensorflow/contrib/timeseries/examples/predict.py b/tensorflow/contrib/timeseries/examples/predict.py index 8147d40caa..b036911314 100644 --- a/tensorflow/contrib/timeseries/examples/predict.py +++ b/tensorflow/contrib/timeseries/examples/predict.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import argparse +import os import sys import numpy as np @@ -40,6 +41,10 @@ except ImportError: FLAGS = None +_MODULE_PATH = os.path.dirname(__file__) +_DEFAULT_DATA_FILE = os.path.join(_MODULE_PATH, "data/period_trend.csv") + + def structural_ensemble_train_and_predict(csv_file_name): # Cycle between 5 latent values over a period of 100. This leads to a very # smooth periodic component (and a small model), which is a good fit for our @@ -115,9 +120,12 @@ def main(unused_argv): if not HAS_MATPLOTLIB: raise ImportError( "Please install matplotlib to generate a plot from this example.") + input_filename = FLAGS.input_filename + if input_filename is None: + input_filename = _DEFAULT_DATA_FILE make_plot("Structural ensemble", - *structural_ensemble_train_and_predict(FLAGS.input_filename)) - make_plot("AR", *ar_train_and_predict(FLAGS.input_filename)) + *structural_ensemble_train_and_predict(input_filename)) + make_plot("AR", *ar_train_and_predict(input_filename)) pyplot.show() @@ -126,7 +134,7 @@ if __name__ == "__main__": parser.add_argument( "--input_filename", type=str, - required=True, - help="Input csv file.") + required=False, + help="Input csv file (omit to use the data/period_trend.csv).") FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) |