author     Andrew Harp <andrewharp@google.com>	2016-09-13 23:13:19 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>	2016-09-14 00:17:57 -0700
commit     68f1a7d05ce031192321a27daaf19529b4cb0a9b (patch)
tree       326855ce87ba54990a99e40916c6da2e3bdb9309 /tensorflow/examples/image_retraining
parent     16a39e515068d0c6141049ea508cab32a527498c (diff)
Merge changes from github.
Change: 133096559
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r--	tensorflow/examples/image_retraining/retrain.py	14
1 file changed, 10 insertions, 4 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index f226de0ce3..7812117e5d 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -165,6 +165,7 @@ MODEL_INPUT_HEIGHT = 299
 MODEL_INPUT_DEPTH = 3
 JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
 RESIZED_INPUT_TENSOR_NAME = 'ResizeBilinear:0'
+MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1  # ~134M


 def create_image_lists(image_dir, testing_percentage, validation_percentage):
@@ -208,6 +209,9 @@ def create_image_lists(image_dir, testing_percentage, validation_percentage):
       continue
     if len(file_list) < 20:
       print('WARNING: Folder has less than 20 images, which may cause issues.')
+    elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
+      print('WARNING: Folder {} has more than {} images. Some images will '
+            'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS))
     label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
     training_images = []
     testing_images = []
@@ -228,7 +232,9 @@ def create_image_lists(image_dir, testing_percentage, validation_percentage):
       # itself, so we do a hash of that and then use that to generate a
       # probability value that we use to assign it.
       hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
-      percentage_hash = (int(hash_name_hashed, 16) % (65536)) * (100 / 65535.0)
+      percentage_hash = ((int(hash_name_hashed, 16) %
+                          (MAX_NUM_IMAGES_PER_CLASS + 1)) *
+                         (100.0 / MAX_NUM_IMAGES_PER_CLASS))
       if percentage_hash < validation_percentage:
         validation_images.append(base_name)
       elif percentage_hash < (testing_percentage + validation_percentage):
@@ -323,7 +329,7 @@ def run_bottleneck_on_image(sess, image_data, image_data_tensor,
   Args:
     sess: Current active TensorFlow Session.
-    image_data: Numpy array of image data.
+    image_data: String of raw JPEG data.
     image_data_tensor: Input data layer in the graph.
     bottleneck_tensor: Layer before the final softmax.
@@ -525,7 +531,7 @@ def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
   for unused_i in range(how_many):
     label_index = random.randrange(class_count)
     label_name = list(image_lists.keys())[label_index]
-    image_index = random.randrange(65536)
+    image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
     bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
                                           image_index, image_dir, category,
                                           bottleneck_dir, jpeg_data_tensor,
@@ -570,7 +576,7 @@ def get_random_distorted_bottlenecks(
   for unused_i in range(how_many):
     label_index = random.randrange(class_count)
     label_name = list(image_lists.keys())[label_index]
-    image_index = random.randrange(65536)
+    image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
     image_path = get_image_path(image_lists, label_name, image_index, image_dir,
                                 category)
     if not gfile.Exists(image_path):
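
Note on the change above: the train/validation/test split in create_image_lists() is computed purely from a hash of each file name, so an image keeps its assignment even when other images are added or removed later. Raising the modulus from 65536 to MAX_NUM_IMAGES_PER_CLASS + 1 (2 ** 27) means classes can hold up to ~134M images before the random index drawn in get_random_cached_bottlenecks() and get_random_distorted_bottlenecks() stops reaching every file, and the new warning flags folders that exceed that limit. The snippet below is a minimal standalone sketch of that bucketing logic for illustration only; the helper name assign_split and the sample file names are invented here and are not part of retrain.py.

from __future__ import print_function

import hashlib
import re

# Same constant as the patch introduces in retrain.py.
MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1  # ~134M


def assign_split(file_name, testing_percentage, validation_percentage):
  """Deterministically maps a file name to 'training', 'testing', or 'validation'."""
  # Variants grouped with '_nohash_' in their name share one hash, so close
  # duplicates of a photo always land in the same split.
  hash_name = re.sub(r'_nohash_.*$', '', file_name)
  hash_name_hashed = hashlib.sha1(hash_name.encode('utf-8')).hexdigest()
  percentage_hash = ((int(hash_name_hashed, 16) %
                      (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                     (100.0 / MAX_NUM_IMAGES_PER_CLASS))
  if percentage_hash < validation_percentage:
    return 'validation'
  elif percentage_hash < (testing_percentage + validation_percentage):
    return 'testing'
  return 'training'


if __name__ == '__main__':
  # Hypothetical file names, just to show that the assignment is stable.
  for name in ['daisy_001.jpg', 'daisy_002.jpg', 'rose_17_nohash_1.jpg']:
    print('{} -> {}'.format(name, assign_split(name, 10, 10)))

Because the bucket index is a pure function of the name, re-running the script after adding new photos never shuffles previously assigned images between sets.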