diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-27 16:33:00 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-27 16:37:09 -0700 |
commit | 50b999a8336d19400ab75aea66fe46eca2f5fe0b (patch) | |
tree | 7cba4f4af6b131c253b65ff9f2923e851184668c /tensorflow/examples/image_retraining | |
parent | d6d58a3a1785785679af56c0f8f131e7312b8226 (diff) |
Merge changes from github.
PiperOrigin-RevId: 160344052
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r-- | tensorflow/examples/image_retraining/retrain.py | 43 |
1 files changed, 38 insertions, 5 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py index 8e3b1a3a36..44a3097d80 100644 --- a/tensorflow/examples/image_retraining/retrain.py +++ b/tensorflow/examples/image_retraining/retrain.py @@ -800,11 +800,27 @@ def add_evaluation_step(result_tensor, ground_truth_tensor): return evaluation_step, prediction -def main(_): +def save_graph_to_file(sess, graph, graph_file_name): + output_graph_def = graph_util.convert_variables_to_constants( + sess, graph.as_graph_def(), [FLAGS.final_tensor_name]) + with gfile.FastGFile(graph_file_name, 'wb') as f: + f.write(output_graph_def.SerializeToString()) + return + + +def prepare_file_system(): # Setup the directory we'll write summaries to for TensorBoard if tf.gfile.Exists(FLAGS.summaries_dir): tf.gfile.DeleteRecursively(FLAGS.summaries_dir) tf.gfile.MakeDirs(FLAGS.summaries_dir) + if FLAGS.intermediate_store_frequency > 0: + ensure_dir_exists(FLAGS.intermediate_output_graphs_dir) + return + + +def main(_): + # Prepare necessary directories that can be used during training + prepare_file_system() # Set up the pre-trained graph. maybe_download_and_extract() @@ -917,6 +933,14 @@ def main(_): (datetime.now(), i, validation_accuracy * 100, len(validation_bottlenecks))) + # Store intermediate results + intermediate_frequency = FLAGS.intermediate_store_frequency + + if intermediate_frequency > 0 and (i % intermediate_frequency == 0) and i > 0: + intermediate_file_name = FLAGS.intermediate_output_graphs_dir + 'intermediate_' + str(i) + '.pb' + print('Save intermediate result to : ' + intermediate_file_name) + save_graph_to_file(sess, graph, intermediate_file_name) + # We've completed all our training, so run a final test evaluation on # some new images we haven't used before. test_bottlenecks, test_ground_truth, test_filenames = ( @@ -940,10 +964,7 @@ def main(_): # Write out the trained graph and labels with the weights stored as # constants. - output_graph_def = graph_util.convert_variables_to_constants( - sess, graph.as_graph_def(), [FLAGS.final_tensor_name]) - with gfile.FastGFile(FLAGS.output_graph, 'wb') as f: - f.write(output_graph_def.SerializeToString()) + save_graph_to_file(sess, graph, FLAGS.output_graph) with gfile.FastGFile(FLAGS.output_labels, 'w') as f: f.write('\n'.join(image_lists.keys()) + '\n') @@ -963,6 +984,18 @@ if __name__ == '__main__': help='Where to save the trained graph.' ) parser.add_argument( + '--intermediate_output_graphs_dir', + type=str, + default='/tmp/intermediate_graph/', + help='Where to save the intermediate graphs.' + ) + parser.add_argument( + '--intermediate_store_frequency', + type=int, + default=0, + help='How many steps to store intermediate graph. If "0" then will not store.' + ) + parser.add_argument( '--output_labels', type=str, default='/tmp/output_labels.txt', |