diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-26 14:00:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-26 14:04:35 -0700 |
commit | 1fa73c53ab95693f070ce70e6be0c644d83c163a (patch) | |
tree | ffbedf825daf1f3453c695a433c8a9cdf93f6019 /tensorflow/examples/image_retraining | |
parent | b13e96e21c1229a905a623111dd89d2bd0cba53b (diff) |
Automated g4 rollback of changelist 160182040
PiperOrigin-RevId: 160190881
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r-- | tensorflow/examples/image_retraining/retrain.py | 43 |
1 files changed, 5 insertions, 38 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py index 44a3097d80..8e3b1a3a36 100644 --- a/tensorflow/examples/image_retraining/retrain.py +++ b/tensorflow/examples/image_retraining/retrain.py @@ -800,27 +800,11 @@ def add_evaluation_step(result_tensor, ground_truth_tensor): return evaluation_step, prediction -def save_graph_to_file(sess, graph, graph_file_name): - output_graph_def = graph_util.convert_variables_to_constants( - sess, graph.as_graph_def(), [FLAGS.final_tensor_name]) - with gfile.FastGFile(graph_file_name, 'wb') as f: - f.write(output_graph_def.SerializeToString()) - return - - -def prepare_file_system(): +def main(_): # Setup the directory we'll write summaries to for TensorBoard if tf.gfile.Exists(FLAGS.summaries_dir): tf.gfile.DeleteRecursively(FLAGS.summaries_dir) tf.gfile.MakeDirs(FLAGS.summaries_dir) - if FLAGS.intermediate_store_frequency > 0: - ensure_dir_exists(FLAGS.intermediate_output_graphs_dir) - return - - -def main(_): - # Prepare necessary directories that can be used during training - prepare_file_system() # Set up the pre-trained graph. maybe_download_and_extract() @@ -933,14 +917,6 @@ def main(_): (datetime.now(), i, validation_accuracy * 100, len(validation_bottlenecks))) - # Store intermediate results - intermediate_frequency = FLAGS.intermediate_store_frequency - - if intermediate_frequency > 0 and (i % intermediate_frequency == 0) and i > 0: - intermediate_file_name = FLAGS.intermediate_output_graphs_dir + 'intermediate_' + str(i) + '.pb' - print('Save intermediate result to : ' + intermediate_file_name) - save_graph_to_file(sess, graph, intermediate_file_name) - # We've completed all our training, so run a final test evaluation on # some new images we haven't used before. test_bottlenecks, test_ground_truth, test_filenames = ( @@ -964,7 +940,10 @@ def main(_): # Write out the trained graph and labels with the weights stored as # constants. - save_graph_to_file(sess, graph, FLAGS.output_graph) + output_graph_def = graph_util.convert_variables_to_constants( + sess, graph.as_graph_def(), [FLAGS.final_tensor_name]) + with gfile.FastGFile(FLAGS.output_graph, 'wb') as f: + f.write(output_graph_def.SerializeToString()) with gfile.FastGFile(FLAGS.output_labels, 'w') as f: f.write('\n'.join(image_lists.keys()) + '\n') @@ -984,18 +963,6 @@ if __name__ == '__main__': help='Where to save the trained graph.' ) parser.add_argument( - '--intermediate_output_graphs_dir', - type=str, - default='/tmp/intermediate_graph/', - help='Where to save the intermediate graphs.' - ) - parser.add_argument( - '--intermediate_store_frequency', - type=int, - default=0, - help='How many steps to store intermediate graph. If "0" then will not store.' - ) - parser.add_argument( '--output_labels', type=str, default='/tmp/output_labels.txt', |