aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/image_retraining/retrain.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/image_retraining/retrain.py')
-rw-r--r--tensorflow/examples/image_retraining/retrain.py43
1 files changed, 38 insertions, 5 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 8e3b1a3a36..44a3097d80 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -800,11 +800,27 @@ def add_evaluation_step(result_tensor, ground_truth_tensor):
return evaluation_step, prediction
-def main(_):
+def save_graph_to_file(sess, graph, graph_file_name):
+ output_graph_def = graph_util.convert_variables_to_constants(
+ sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
+ with gfile.FastGFile(graph_file_name, 'wb') as f:
+ f.write(output_graph_def.SerializeToString())
+ return
+
+
+def prepare_file_system():
# Setup the directory we'll write summaries to for TensorBoard
if tf.gfile.Exists(FLAGS.summaries_dir):
tf.gfile.DeleteRecursively(FLAGS.summaries_dir)
tf.gfile.MakeDirs(FLAGS.summaries_dir)
+ if FLAGS.intermediate_store_frequency > 0:
+ ensure_dir_exists(FLAGS.intermediate_output_graphs_dir)
+ return
+
+
+def main(_):
+ # Prepare necessary directories that can be used during training
+ prepare_file_system()
# Set up the pre-trained graph.
maybe_download_and_extract()
@@ -917,6 +933,14 @@ def main(_):
(datetime.now(), i, validation_accuracy * 100,
len(validation_bottlenecks)))
+ # Store intermediate results
+ intermediate_frequency = FLAGS.intermediate_store_frequency
+
+ if intermediate_frequency > 0 and (i % intermediate_frequency == 0) and i > 0:
+ intermediate_file_name = FLAGS.intermediate_output_graphs_dir + 'intermediate_' + str(i) + '.pb'
+ print('Save intermediate result to : ' + intermediate_file_name)
+ save_graph_to_file(sess, graph, intermediate_file_name)
+
# We've completed all our training, so run a final test evaluation on
# some new images we haven't used before.
test_bottlenecks, test_ground_truth, test_filenames = (
@@ -940,10 +964,7 @@ def main(_):
# Write out the trained graph and labels with the weights stored as
# constants.
- output_graph_def = graph_util.convert_variables_to_constants(
- sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
- with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
- f.write(output_graph_def.SerializeToString())
+ save_graph_to_file(sess, graph, FLAGS.output_graph)
with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
f.write('\n'.join(image_lists.keys()) + '\n')
@@ -963,6 +984,18 @@ if __name__ == '__main__':
help='Where to save the trained graph.'
)
parser.add_argument(
+ '--intermediate_output_graphs_dir',
+ type=str,
+ default='/tmp/intermediate_graph/',
+ help='Where to save the intermediate graphs.'
+ )
+ parser.add_argument(
+ '--intermediate_store_frequency',
+ type=int,
+ default=0,
+ help='How many steps to store intermediate graph. If "0" then will not store.'
+ )
+ parser.add_argument(
'--output_labels',
type=str,
default='/tmp/output_labels.txt',