aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-08-20 17:52:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 17:58:27 -0700
commit331e53b1d0c922fdf923e17a803e66e694bd5368 (patch)
tree637eb2e8fc7a32a6df761af48bac4d1b03f86589
parentd01bdee827aa0b3e6688dc2bfd48ad65f3891e7a (diff)
Make the "predict" time series example more user friendly.
Doesn't require an input_filename flag, uses the default example input instead. PiperOrigin-RevId: 209518060
-rw-r--r--tensorflow/contrib/timeseries/examples/BUILD2
-rw-r--r--tensorflow/contrib/timeseries/examples/predict.py16
2 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD
index 355303acf6..71b0d48798 100644
--- a/tensorflow/contrib/timeseries/examples/BUILD
+++ b/tensorflow/contrib/timeseries/examples/BUILD
@@ -16,6 +16,7 @@ config_setting(
py_binary(
name = "predict",
srcs = ["predict.py"],
+ data = ["data/period_trend.csv"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = select({
@@ -31,7 +32,6 @@ py_test(
name = "predict_test",
timeout = "long", # Moderate but for asan
srcs = ["predict_test.py"],
- data = ["data/period_trend.csv"],
srcs_version = "PY2AND3",
tags = [
"no_windows", # TODO: needs investigation on Windows
diff --git a/tensorflow/contrib/timeseries/examples/predict.py b/tensorflow/contrib/timeseries/examples/predict.py
index 8147d40caa..b036911314 100644
--- a/tensorflow/contrib/timeseries/examples/predict.py
+++ b/tensorflow/contrib/timeseries/examples/predict.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import argparse
+import os
import sys
import numpy as np
@@ -40,6 +41,10 @@ except ImportError:
FLAGS = None
+_MODULE_PATH = os.path.dirname(__file__)
+_DEFAULT_DATA_FILE = os.path.join(_MODULE_PATH, "data/period_trend.csv")
+
+
def structural_ensemble_train_and_predict(csv_file_name):
# Cycle between 5 latent values over a period of 100. This leads to a very
# smooth periodic component (and a small model), which is a good fit for our
@@ -115,9 +120,12 @@ def main(unused_argv):
if not HAS_MATPLOTLIB:
raise ImportError(
"Please install matplotlib to generate a plot from this example.")
+ input_filename = FLAGS.input_filename
+ if input_filename is None:
+ input_filename = _DEFAULT_DATA_FILE
make_plot("Structural ensemble",
- *structural_ensemble_train_and_predict(FLAGS.input_filename))
- make_plot("AR", *ar_train_and_predict(FLAGS.input_filename))
+ *structural_ensemble_train_and_predict(input_filename))
+ make_plot("AR", *ar_train_and_predict(input_filename))
pyplot.show()
@@ -126,7 +134,7 @@ if __name__ == "__main__":
parser.add_argument(
"--input_filename",
type=str,
- required=True,
- help="Input csv file.")
+ required=False,
+ help="Input csv file (omit to use the data/period_trend.csv).")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)