about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/examples/learn
diff options
context:
space:
mode:
author: Jianwei Xie <xiejw@google.com>, 2017-10-18 14:06:53 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2017-10-18 14:10:35 -0700
commit: 71cea5ba4eafabb4a5515025bd1b6106faa0c958 (patch)
tree: b3bb7118f4c6fdc036d2397789bfcc7a6e72046b /tensorflow/examples/learn
parent: cadcda216ec7d6f5f3e36dfc7863634f4f03f71f (diff)
Modify the learn examples wide_and_deep to use tf.estimator.train_and_evaluate.
PiperOrigin-RevId: 172652065
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r-- tensorflow/examples/learn/wide_n_deep_tutorial.py 55
1 file changed, 32 insertions, 23 deletions
diff --git a/tensorflow/examples/learn/wide_n_deep_tutorial.py b/tensorflow/examples/learn/wide_n_deep_tutorial.py
index 7b9381311c..e447b3e24e 100644
--- a/tensorflow/examples/learn/wide_n_deep_tutorial.py
+++ b/tensorflow/examples/learn/wide_n_deep_tutorial.py
@@ -107,6 +107,9 @@ deep_columns = [
]
+FLAGS = None
+
+
def maybe_download(train_data, test_data):
"""Maybe downloads training data and returns train and test file names."""
if train_data:
@@ -154,7 +157,14 @@ def build_estimator(model_dir, model_type):
def input_fn(data_file, num_epochs, shuffle):
- """Input builder function."""
+ """Returns an `input_fn` required by Estimator train/evaluate.
+
+ Args:
+ data_file: The file path to the dataset.
+ num_epochs: Number of epochs to iterate over data. If `None`, `input_fn`
+ will generate infinite stream of data.
+ shuffle: bool, whether to read the data in random order.
+ """
df_data = pd.read_csv(
tf.gfile.Open(data_file),
names=CSV_COLUMNS,
@@ -164,43 +174,42 @@ def input_fn(data_file, num_epochs, shuffle):
# remove NaN elements
df_data = df_data.dropna(how="any", axis=0)
labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
+
return tf.estimator.inputs.pandas_input_fn(
x=df_data,
y=labels,
batch_size=100,
num_epochs=num_epochs,
shuffle=shuffle,
- num_threads=5)
+ num_threads=1)
-def train_and_eval(model_dir, model_type, train_steps, train_data, test_data):
- """Train and evaluate the model."""
- train_file_name, test_file_name = maybe_download(train_data, test_data)
+def main(_):
+ tf.logging.set_verbosity(tf.logging.INFO)
+
+ train_file_name, test_file_name = maybe_download(FLAGS.train_data,
+ FLAGS.test_data)
+
# Specify file path below if want to find the output easily
- model_dir = tempfile.mkdtemp() if not model_dir else model_dir
+ model_dir = FLAGS.model_dir if FLAGS.model_dir else tempfile.mkdtemp()
- m = build_estimator(model_dir, model_type)
- # set num_epochs to None to get infinite stream of data.
- m.train(
+ estimator = build_estimator(model_dir, FLAGS.model_type)
+
+ # `tf.estimator.TrainSpec`, `tf.estimator.EvalSpec`, and
+ # `tf.estimator.train_and_evaluate` API are available in TF 1.4.
+ train_spec = tf.estimator.TrainSpec(
input_fn=input_fn(train_file_name, num_epochs=None, shuffle=True),
- steps=train_steps)
- # set steps to None to run evaluation until all data consumed.
- results = m.evaluate(
+ max_steps=FLAGS.train_steps)
+
+ eval_spec = tf.estimator.EvalSpec(
input_fn=input_fn(test_file_name, num_epochs=1, shuffle=False),
+ # set steps to None to run evaluation until all data consumed.
steps=None)
- print("model directory = %s" % model_dir)
- for key in sorted(results):
- print("%s: %s" % (key, results[key]))
- # Manual cleanup
- shutil.rmtree(model_dir)
-
-
-FLAGS = None
+ tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
-def main(_):
- train_and_eval(FLAGS.model_dir, FLAGS.model_type, FLAGS.train_steps,
- FLAGS.train_data, FLAGS.test_data)
+ # Manual cleanup
+ shutil.rmtree(model_dir)
if __name__ == "__main__":