diff options
author | 2016-09-13 23:13:19 -0800 | |
---|---|---|
committer | 2016-09-14 00:17:57 -0700 | |
commit | 68f1a7d05ce031192321a27daaf19529b4cb0a9b (patch) | |
tree | 326855ce87ba54990a99e40916c6da2e3bdb9309 /tensorflow/examples/image_retraining | |
parent | 16a39e515068d0c6141049ea508cab32a527498c (diff) |
Merge changes from github.
Change: 133096559
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r-- | tensorflow/examples/image_retraining/retrain.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py index f226de0ce3..7812117e5d 100644 --- a/tensorflow/examples/image_retraining/retrain.py +++ b/tensorflow/examples/image_retraining/retrain.py @@ -165,6 +165,7 @@ MODEL_INPUT_HEIGHT = 299 MODEL_INPUT_DEPTH = 3 JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0' RESIZED_INPUT_TENSOR_NAME = 'ResizeBilinear:0' +MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1 # ~134M def create_image_lists(image_dir, testing_percentage, validation_percentage): @@ -208,6 +209,9 @@ def create_image_lists(image_dir, testing_percentage, validation_percentage): continue if len(file_list) < 20: print('WARNING: Folder has less than 20 images, which may cause issues.') + elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS: + print('WARNING: Folder {} has more than {} images. Some images will ' + 'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS)) label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower()) training_images = [] testing_images = [] @@ -228,7 +232,9 @@ def create_image_lists(image_dir, testing_percentage, validation_percentage): # itself, so we do a hash of that and then use that to generate a # probability value that we use to assign it. hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest() - percentage_hash = (int(hash_name_hashed, 16) % (65536)) * (100 / 65535.0) + percentage_hash = ((int(hash_name_hashed, 16) % + (MAX_NUM_IMAGES_PER_CLASS + 1)) * + (100.0 / MAX_NUM_IMAGES_PER_CLASS)) if percentage_hash < validation_percentage: validation_images.append(base_name) elif percentage_hash < (testing_percentage + validation_percentage): @@ -323,7 +329,7 @@ def run_bottleneck_on_image(sess, image_data, image_data_tensor, Args: sess: Current active TensorFlow Session. - image_data: Numpy array of image data. + image_data: String of raw JPEG data. image_data_tensor: Input data layer in the graph. bottleneck_tensor: Layer before the final softmax. @@ -525,7 +531,7 @@ def get_random_cached_bottlenecks(sess, image_lists, how_many, category, for unused_i in range(how_many): label_index = random.randrange(class_count) label_name = list(image_lists.keys())[label_index] - image_index = random.randrange(65536) + image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1) bottleneck = get_or_create_bottleneck(sess, image_lists, label_name, image_index, image_dir, category, bottleneck_dir, jpeg_data_tensor, @@ -570,7 +576,7 @@ def get_random_distorted_bottlenecks( for unused_i in range(how_many): label_index = random.randrange(class_count) label_name = list(image_lists.keys())[label_index] - image_index = random.randrange(65536) + image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1) image_path = get_image_path(image_lists, label_name, image_index, image_dir, category) if not gfile.Exists(image_path): |