aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/image_retraining
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-26 14:00:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-26 14:04:35 -0700
commit1fa73c53ab95693f070ce70e6be0c644d83c163a (patch)
treeffbedf825daf1f3453c695a433c8a9cdf93f6019 /tensorflow/examples/image_retraining
parentb13e96e21c1229a905a623111dd89d2bd0cba53b (diff)
Automated g4 rollback of changelist 160182040
PiperOrigin-RevId: 160190881
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r--tensorflow/examples/image_retraining/retrain.py43
1 files changed, 5 insertions, 38 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 44a3097d80..8e3b1a3a36 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -800,27 +800,11 @@ def add_evaluation_step(result_tensor, ground_truth_tensor):
return evaluation_step, prediction
-def save_graph_to_file(sess, graph, graph_file_name):
- output_graph_def = graph_util.convert_variables_to_constants(
- sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
- with gfile.FastGFile(graph_file_name, 'wb') as f:
- f.write(output_graph_def.SerializeToString())
- return
-
-
-def prepare_file_system():
+def main(_):
# Setup the directory we'll write summaries to for TensorBoard
if tf.gfile.Exists(FLAGS.summaries_dir):
tf.gfile.DeleteRecursively(FLAGS.summaries_dir)
tf.gfile.MakeDirs(FLAGS.summaries_dir)
- if FLAGS.intermediate_store_frequency > 0:
- ensure_dir_exists(FLAGS.intermediate_output_graphs_dir)
- return
-
-
-def main(_):
- # Prepare necessary directories that can be used during training
- prepare_file_system()
# Set up the pre-trained graph.
maybe_download_and_extract()
@@ -933,14 +917,6 @@ def main(_):
(datetime.now(), i, validation_accuracy * 100,
len(validation_bottlenecks)))
- # Store intermediate results
- intermediate_frequency = FLAGS.intermediate_store_frequency
-
- if intermediate_frequency > 0 and (i % intermediate_frequency == 0) and i > 0:
- intermediate_file_name = FLAGS.intermediate_output_graphs_dir + 'intermediate_' + str(i) + '.pb'
- print('Save intermediate result to : ' + intermediate_file_name)
- save_graph_to_file(sess, graph, intermediate_file_name)
-
# We've completed all our training, so run a final test evaluation on
# some new images we haven't used before.
test_bottlenecks, test_ground_truth, test_filenames = (
@@ -964,7 +940,10 @@ def main(_):
# Write out the trained graph and labels with the weights stored as
# constants.
- save_graph_to_file(sess, graph, FLAGS.output_graph)
+ output_graph_def = graph_util.convert_variables_to_constants(
+ sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
+ with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
+ f.write(output_graph_def.SerializeToString())
with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
f.write('\n'.join(image_lists.keys()) + '\n')
@@ -984,18 +963,6 @@ if __name__ == '__main__':
help='Where to save the trained graph.'
)
parser.add_argument(
- '--intermediate_output_graphs_dir',
- type=str,
- default='/tmp/intermediate_graph/',
- help='Where to save the intermediate graphs.'
- )
- parser.add_argument(
- '--intermediate_store_frequency',
- type=int,
- default=0,
- help='How many steps to store intermediate graph. If "0" then will not store.'
- )
- parser.add_argument(
'--output_labels',
type=str,
default='/tmp/output_labels.txt',