diff options
author | Benoit Steiner <bsteiner@google.com> | 2017-02-08 09:25:09 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-08 09:50:05 -0800 |
commit | 639b4e71f532761a4840b1cdbaea55ad0917c75b (patch) | |
tree | 5116415b1d9ff82f054dd4feeadd81cb833d6435 /tensorflow/examples/image_retraining | |
parent | 15ff7b702788c0cf75bb8d5ce090f06490098cf7 (diff) |
Merge changes from github.
Change: 146918929
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r-- | tensorflow/examples/image_retraining/retrain.py | 43 |
1 files changed, 26 insertions, 17 deletions
```diff
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index c5518e2603..e612eb7424 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -346,6 +346,17 @@ def read_list_of_floats_from_file(file_path):
 
 bottleneck_path_2_bottleneck_values = {}
 
+def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
+                           image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor):
+  print('Creating bottleneck at ' + bottleneck_path)
+  image_path = get_image_path(image_lists, label_name, index, image_dir, category)
+  if not gfile.Exists(image_path):
+    tf.logging.fatal('File does not exist %s', image_path)
+  image_data = gfile.FastGFile(image_path, 'rb').read()
+  bottleneck_values = run_bottleneck_on_image(sess, image_data, jpeg_data_tensor, bottleneck_tensor)
+  bottleneck_string = ','.join(str(x) for x in bottleneck_values)
+  with open(bottleneck_path, 'w') as bottleneck_file:
+    bottleneck_file.write(bottleneck_string)
 
 def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
                              category, bottleneck_dir, jpeg_data_tensor,
@@ -376,28 +387,25 @@ def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
   sub_dir = label_lists['dir']
   sub_dir_path = os.path.join(bottleneck_dir, sub_dir)
   ensure_dir_exists(sub_dir_path)
-  bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
-                                        bottleneck_dir, category)
+  bottleneck_path = get_bottleneck_path(image_lists, label_name, index, bottleneck_dir, category)
   if not os.path.exists(bottleneck_path):
-    print('Creating bottleneck at ' + bottleneck_path)
-    image_path = get_image_path(image_lists, label_name, index, image_dir,
-                                category)
-    if not gfile.Exists(image_path):
-      tf.logging.fatal('File does not exist %s', image_path)
-    image_data = gfile.FastGFile(image_path, 'rb').read()
-    bottleneck_values = run_bottleneck_on_image(sess, image_data,
-                                                jpeg_data_tensor,
-                                                bottleneck_tensor)
-    bottleneck_string = ','.join(str(x) for x in bottleneck_values)
-    with open(bottleneck_path, 'w') as bottleneck_file:
-      bottleneck_file.write(bottleneck_string)
-
+    create_bottleneck_file(bottleneck_path, image_lists, label_name, index, image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor)
   with open(bottleneck_path, 'r') as bottleneck_file:
     bottleneck_string = bottleneck_file.read()
-  bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
+  did_hit_error = False
+  try:
+    bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
+  except:
+    print("Invalid float found, recreating bottleneck")
+    did_hit_error = True
+  if did_hit_error:
+    create_bottleneck_file(bottleneck_path, image_lists, label_name, index, image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor)
+    with open(bottleneck_path, 'r') as bottleneck_file:
+      bottleneck_string = bottleneck_file.read()
+    # Allow exceptions to propagate here, since they shouldn't happen after a fresh creation
+    bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
   return bottleneck_values
 
-
 def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
                       jpeg_data_tensor, bottleneck_tensor):
   """Ensures all the training, testing, and validation bottlenecks are cached.
@@ -430,6 +438,7 @@ def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
         get_or_create_bottleneck(sess, image_lists, label_name, index,
                                  image_dir, category, bottleneck_dir,
                                  jpeg_data_tensor, bottleneck_tensor)
+
         how_many_bottlenecks += 1
         if how_many_bottlenecks % 100 == 0:
           print(str(how_many_bottlenecks) + ' bottleneck files created.')
```