diff options
author | 2017-10-18 14:06:53 -0700 | |
---|---|---|
committer | 2017-10-18 14:10:35 -0700 | |
commit | 71cea5ba4eafabb4a5515025bd1b6106faa0c958 (patch) | |
tree | b3bb7118f4c6fdc036d2397789bfcc7a6e72046b /tensorflow/examples/learn | |
parent | cadcda216ec7d6f5f3e36dfc7863634f4f03f71f (diff) |
Modify the learn examples wide_and_deep to use tf.estimator.train_and_evaluate.
PiperOrigin-RevId: 172652065
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r-- | tensorflow/examples/learn/wide_n_deep_tutorial.py | 55 |
1 files changed, 32 insertions, 23 deletions
diff --git a/tensorflow/examples/learn/wide_n_deep_tutorial.py b/tensorflow/examples/learn/wide_n_deep_tutorial.py index 7b9381311c..e447b3e24e 100644 --- a/tensorflow/examples/learn/wide_n_deep_tutorial.py +++ b/tensorflow/examples/learn/wide_n_deep_tutorial.py @@ -107,6 +107,9 @@ deep_columns = [ ] +FLAGS = None + + def maybe_download(train_data, test_data): """Maybe downloads training data and returns train and test file names.""" if train_data: @@ -154,7 +157,14 @@ def build_estimator(model_dir, model_type): def input_fn(data_file, num_epochs, shuffle): - """Input builder function.""" + """Returns an `input_fn` required by Estimator train/evaluate. + + Args: + data_file: The file path to the dataset. + num_epochs: Number of epochs to iterate over data. If `None`, `input_fn` + will generate infinite stream of data. + shuffle: bool, whether to read the data in random order. + """ df_data = pd.read_csv( tf.gfile.Open(data_file), names=CSV_COLUMNS, @@ -164,43 +174,42 @@ def input_fn(data_file, num_epochs, shuffle): # remove NaN elements df_data = df_data.dropna(how="any", axis=0) labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int) + return tf.estimator.inputs.pandas_input_fn( x=df_data, y=labels, batch_size=100, num_epochs=num_epochs, shuffle=shuffle, - num_threads=5) + num_threads=1) -def train_and_eval(model_dir, model_type, train_steps, train_data, test_data): - """Train and evaluate the model.""" - train_file_name, test_file_name = maybe_download(train_data, test_data) +def main(_): + tf.logging.set_verbosity(tf.logging.INFO) + + train_file_name, test_file_name = maybe_download(FLAGS.train_data, + FLAGS.test_data) + # Specify file path below if want to find the output easily - model_dir = tempfile.mkdtemp() if not model_dir else model_dir + model_dir = FLAGS.model_dir if FLAGS.model_dir else tempfile.mkdtemp() - m = build_estimator(model_dir, model_type) - # set num_epochs to None to get infinite stream of data. - m.train( + estimator = build_estimator(model_dir, FLAGS.model_type) + + # `tf.estimator.TrainSpec`, `tf.estimator.EvalSpec`, and + # `tf.estimator.train_and_evaluate` API are available in TF 1.4. + train_spec = tf.estimator.TrainSpec( input_fn=input_fn(train_file_name, num_epochs=None, shuffle=True), - steps=train_steps) - # set steps to None to run evaluation until all data consumed. - results = m.evaluate( + max_steps=FLAGS.train_steps) + + eval_spec = tf.estimator.EvalSpec( input_fn=input_fn(test_file_name, num_epochs=1, shuffle=False), + # set steps to None to run evaluation until all data consumed. steps=None) - print("model directory = %s" % model_dir) - for key in sorted(results): - print("%s: %s" % (key, results[key])) - # Manual cleanup - shutil.rmtree(model_dir) - - -FLAGS = None + tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) -def main(_): - train_and_eval(FLAGS.model_dir, FLAGS.model_type, FLAGS.train_steps, - FLAGS.train_data, FLAGS.test_data) + # Manual cleanup + shutil.rmtree(model_dir) if __name__ == "__main__": |