about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/examples/image_retraining/retrain.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/image_retraining/retrain.py')
-rw-r--r-- tensorflow/examples/image_retraining/retrain.py 43
1 files changed, 26 insertions, 17 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index c5518e2603..e612eb7424 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -346,6 +346,17 @@ def read_list_of_floats_from_file(file_path):
bottleneck_path_2_bottleneck_values = {}
+def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
+                           image_dir, category, sess, jpeg_data_tensor,
+                           bottleneck_tensor):
+  """Computes the bottleneck values for one image and writes them to disk.
+
+  The values returned by run_bottleneck_on_image are stored in
+  `bottleneck_path` as a single line of comma-separated strings.
+
+  Args:
+    bottleneck_path: Path of the bottleneck file to (over)write.
+    image_lists: Passed through to get_image_path to locate the image.
+    label_name: Passed through to get_image_path.
+    index: Passed through to get_image_path.
+    image_dir: Passed through to get_image_path.
+    category: Passed through to get_image_path.
+    sess: Passed through to run_bottleneck_on_image.
+    jpeg_data_tensor: Passed through to run_bottleneck_on_image.
+    bottleneck_tensor: Passed through to run_bottleneck_on_image.
+  """
+  print('Creating bottleneck at ' + bottleneck_path)
+  image_path = get_image_path(image_lists, label_name, index, image_dir,
+                              category)
+  if not gfile.Exists(image_path):
+    tf.logging.fatal('File does not exist %s', image_path)
+  # Read inside a `with` block so the image file handle is closed as soon as
+  # the bytes are loaded, rather than being left for the garbage collector.
+  with gfile.FastGFile(image_path, 'rb') as image_file:
+    image_data = image_file.read()
+  bottleneck_values = run_bottleneck_on_image(sess, image_data,
+                                              jpeg_data_tensor,
+                                              bottleneck_tensor)
+  bottleneck_string = ','.join(str(x) for x in bottleneck_values)
+  with open(bottleneck_path, 'w') as bottleneck_file:
+    bottleneck_file.write(bottleneck_string)
def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
category, bottleneck_dir, jpeg_data_tensor,
@@ -376,28 +387,25 @@ def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
sub_dir = label_lists['dir']
sub_dir_path = os.path.join(bottleneck_dir, sub_dir)
ensure_dir_exists(sub_dir_path)
- bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
- bottleneck_dir, category)
+ bottleneck_path = get_bottleneck_path(image_lists, label_name, index, bottleneck_dir, category)
if not os.path.exists(bottleneck_path):
- print('Creating bottleneck at ' + bottleneck_path)
- image_path = get_image_path(image_lists, label_name, index, image_dir,
- category)
- if not gfile.Exists(image_path):
- tf.logging.fatal('File does not exist %s', image_path)
- image_data = gfile.FastGFile(image_path, 'rb').read()
- bottleneck_values = run_bottleneck_on_image(sess, image_data,
- jpeg_data_tensor,
- bottleneck_tensor)
- bottleneck_string = ','.join(str(x) for x in bottleneck_values)
- with open(bottleneck_path, 'w') as bottleneck_file:
- bottleneck_file.write(bottleneck_string)
-
+ create_bottleneck_file(bottleneck_path, image_lists, label_name, index, image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor)
with open(bottleneck_path, 'r') as bottleneck_file:
bottleneck_string = bottleneck_file.read()
- bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
+ did_hit_error = False
+ try:
+ bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
+ except:
+ print("Invalid float found, recreating bottleneck")
+ did_hit_error = True
+ if did_hit_error:
+ create_bottleneck_file(bottleneck_path, image_lists, label_name, index, image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor)
+ with open(bottleneck_path, 'r') as bottleneck_file:
+ bottleneck_string = bottleneck_file.read()
+ # Allow exceptions to propagate here, since they shouldn't happen after a fresh creation
+ bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
return bottleneck_values
-
def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
jpeg_data_tensor, bottleneck_tensor):
"""Ensures all the training, testing, and validation bottlenecks are cached.
@@ -430,6 +438,7 @@ def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
get_or_create_bottleneck(sess, image_lists, label_name, index,
image_dir, category, bottleneck_dir,
jpeg_data_tensor, bottleneck_tensor)
+
how_many_bottlenecks += 1
if how_many_bottlenecks % 100 == 0:
print(str(how_many_bottlenecks) + ' bottleneck files created.')