about summary refs log tree commit diff homepage
path: root/tensorflow/examples/image_retraining
diff options
context:
space:
mode:
author: Pete Warden <petewarden@google.com> 2017-07-05 13:12:07 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-07-05 13:20:42 -0700
commit: 055500bbcea60513c0160d213a10a7055f079312 (patch)
tree: 794a2f2f3340d80070c9ce743ec2858bc7f36a09 /tensorflow/examples/image_retraining
parent: e4351ac3f8d4c4134dbd90ec9f8f03fdef31c20a (diff)
Added Mobilenet support to TensorFlow for Poets training script.
PiperOrigin-RevId: 160995543
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r--  tensorflow/examples/image_retraining/retrain.py      | 480
-rw-r--r--  tensorflow/examples/image_retraining/retrain_test.py |  29
2 files changed, 357 insertions, 152 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 44a3097d80..ac39639af1 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -12,18 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Simple transfer learning with an Inception v3 architecture model.
+r"""Simple transfer learning with Inception v3 or Mobilenet models.
With support for TensorBoard.
-This example shows how to take a Inception v3 architecture model trained on
+This example shows how to take a Inception v3 or Mobilenet model trained on
ImageNet images, and train a new top layer that can recognize other classes of
images.
-The top layer receives as input a 2048-dimensional vector for each image. We
-train a softmax layer on top of this representation. Assuming the softmax layer
-contains N labels, this corresponds to learning N + 2048*N model parameters
-corresponding to the learned biases and weights.
+The top layer receives as input a 2048-dimensional vector (1001-dimensional for
+Mobilenet) for each image. We train a softmax layer on top of this
+representation. Assuming the softmax layer contains N labels, this corresponds
+to learning N + 2048*N (or 1001*N) model parameters corresponding to the
+learned biases and weights.
Here's an example, which assumes you have a folder containing class-named
subfolders, each full of images for each label. The example folder flower_photos
@@ -62,6 +63,23 @@ in.
This produces a new model file that can be loaded and run by any TensorFlow
program, for example the label_image sample code.
+By default this script will use the high accuracy, but comparatively large and
+slow Inception v3 model architecture. It's recommended that you start with this
+to validate that you have gathered good training data, but if you want to deploy
+on resource-limited platforms, you can try the `--architecture` flag with a
+Mobilenet model. For example:
+
+```bash
+python tensorflow/examples/image_retraining/retrain.py \
+ --image_dir ~/flower_photos --architecture mobilenet_1.0_224
+```
+
+There are 32 different Mobilenet models to choose from, with a variety of file
+size and latency options. The first number can be '1.0', '0.75', '0.50', or
+'0.25' to control the size, and the second controls the input image size, either
+'224', '192', '160', or '128', with smaller sizes running faster. See
+https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
+for more information on Mobilenet.
To use with TensorBoard:
@@ -82,7 +100,6 @@ import hashlib
import os.path
import random
import re
-import struct
import sys
import tarfile
@@ -101,16 +118,6 @@ FLAGS = None
# we're using for Inception v3. These include things like tensor names and their
# sizes. If you want to adapt this script to work with another model, you will
# need to update these to reflect the values in the network you're using.
-# pylint: disable=line-too-long
-DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
-# pylint: enable=line-too-long
-BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
-BOTTLENECK_TENSOR_SIZE = 2048
-MODEL_INPUT_WIDTH = 299
-MODEL_INPUT_HEIGHT = 299
-MODEL_INPUT_DEPTH = 3
-JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
-RESIZED_INPUT_TENSOR_NAME = 'ResizeBilinear:0'
MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1 # ~134M
@@ -131,7 +138,7 @@ def create_image_lists(image_dir, testing_percentage, validation_percentage):
into training, testing, and validation sets within each label.
"""
if not gfile.Exists(image_dir):
- print("Image directory '" + image_dir + "' not found.")
+ tf.logging.error("Image directory '" + image_dir + "' not found.")
return None
result = {}
sub_dirs = [x[0] for x in gfile.Walk(image_dir)]
@@ -146,18 +153,20 @@ def create_image_lists(image_dir, testing_percentage, validation_percentage):
dir_name = os.path.basename(sub_dir)
if dir_name == image_dir:
continue
- print("Looking for images in '" + dir_name + "'")
+ tf.logging.info("Looking for images in '" + dir_name + "'")
for extension in extensions:
file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
file_list.extend(gfile.Glob(file_glob))
if not file_list:
- print('No files found')
+ tf.logging.warning('No files found')
continue
if len(file_list) < 20:
- print('WARNING: Folder has less than 20 images, which may cause issues.')
+ tf.logging.warning(
+ 'WARNING: Folder has less than 20 images, which may cause issues.')
elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
- print('WARNING: Folder {} has more than {} images. Some images will '
- 'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS))
+ tf.logging.warning(
+ 'WARNING: Folder {} has more than {} images. Some images will '
+ 'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS))
label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
training_images = []
testing_images = []
@@ -230,7 +239,7 @@ def get_image_path(image_lists, label_name, index, image_dir, category):
def get_bottleneck_path(image_lists, label_name, index, bottleneck_dir,
- category):
+ category, architecture):
""""Returns a path to a bottleneck file for a label at the given index.
Args:
@@ -241,35 +250,42 @@ def get_bottleneck_path(image_lists, label_name, index, bottleneck_dir,
bottleneck_dir: Folder string holding cached files of bottleneck values.
category: Name string of set to pull images from - training, testing, or
validation.
+ architecture: The name of the model architecture.
Returns:
File system path string to an image that meets the requested parameters.
"""
return get_image_path(image_lists, label_name, index, bottleneck_dir,
- category) + '.txt'
+ category) + '_' + architecture + '.txt'
-def create_inception_graph():
+def create_model_graph(model_info):
""""Creates a graph from saved GraphDef file and returns a Graph object.
+ Args:
+ model_info: Dictionary containing information about the model architecture.
+
Returns:
Graph holding the trained Inception network, and various tensors we'll be
manipulating.
"""
with tf.Graph().as_default() as graph:
- model_filename = os.path.join(
- FLAGS.model_dir, 'classify_image_graph_def.pb')
- with gfile.FastGFile(model_filename, 'rb') as f:
+ model_path = os.path.join(FLAGS.model_dir, model_info['model_file_name'])
+ with gfile.FastGFile(model_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
- bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
- tf.import_graph_def(graph_def, name='', return_elements=[
- BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
- RESIZED_INPUT_TENSOR_NAME]))
- return graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor
+ bottleneck_tensor, resized_input_tensor = (tf.import_graph_def(
+ graph_def,
+ name='',
+ return_elements=[
+ model_info['bottleneck_tensor_name'],
+ model_info['resized_input_tensor_name'],
+ ]))
+ return graph, bottleneck_tensor, resized_input_tensor
def run_bottleneck_on_image(sess, image_data, image_data_tensor,
+ decoded_image_tensor, resized_input_tensor,
bottleneck_tensor):
"""Runs inference on an image to extract the 'bottleneck' summary layer.
@@ -277,28 +293,36 @@ def run_bottleneck_on_image(sess, image_data, image_data_tensor,
sess: Current active TensorFlow Session.
image_data: String of raw JPEG data.
image_data_tensor: Input data layer in the graph.
+ decoded_image_tensor: Output of initial image resizing and preprocessing.
+ resized_input_tensor: The input node of the recognition graph.
bottleneck_tensor: Layer before the final softmax.
Returns:
Numpy array of bottleneck values.
"""
- bottleneck_values = sess.run(
- bottleneck_tensor,
- {image_data_tensor: image_data})
+ # First decode the JPEG image, resize it, and rescale the pixel values.
+ resized_input_values = sess.run(decoded_image_tensor,
+ {image_data_tensor: image_data})
+ # Then run it through the recognition network.
+ bottleneck_values = sess.run(bottleneck_tensor,
+ {resized_input_tensor: resized_input_values})
bottleneck_values = np.squeeze(bottleneck_values)
return bottleneck_values
-def maybe_download_and_extract():
+def maybe_download_and_extract(data_url):
"""Download and extract model tar file.
If the pretrained model we're using doesn't already exist, this function
downloads it from the TensorFlow.org website and unpacks it into a directory.
+
+ Args:
+ data_url: Web location of the tar file containing the pretrained model.
"""
dest_directory = FLAGS.model_dir
if not os.path.exists(dest_directory):
os.makedirs(dest_directory)
- filename = DATA_URL.split('/')[-1]
+ filename = data_url.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
if not os.path.exists(filepath):
@@ -308,12 +332,11 @@ def maybe_download_and_extract():
float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
- filepath, _ = urllib.request.urlretrieve(DATA_URL,
- filepath,
- _progress)
+ filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
print()
statinfo = os.stat(filepath)
- print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
+ tf.logging.info('Successfully downloaded', filename, statinfo.st_size,
+ 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
@@ -327,43 +350,15 @@ def ensure_dir_exists(dir_name):
os.makedirs(dir_name)
-def write_list_of_floats_to_file(list_of_floats, file_path):
- """Writes a given list of floats to a binary file.
-
- Args:
- list_of_floats: List of floats we want to write to a file.
- file_path: Path to a file where list of floats will be stored.
-
- """
-
- s = struct.pack('d' * BOTTLENECK_TENSOR_SIZE, *list_of_floats)
- with open(file_path, 'wb') as f:
- f.write(s)
-
-
-def read_list_of_floats_from_file(file_path):
- """Reads list of floats from a given file.
-
- Args:
- file_path: Path to a file where list of floats was stored.
- Returns:
- Array of bottleneck values (list of floats).
-
- """
-
- with open(file_path, 'rb') as f:
- s = struct.unpack('d' * BOTTLENECK_TENSOR_SIZE, f.read())
- return list(s)
-
-
bottleneck_path_2_bottleneck_values = {}
def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
image_dir, category, sess, jpeg_data_tensor,
+ decoded_image_tensor, resized_input_tensor,
bottleneck_tensor):
"""Create a single bottleneck file."""
- print('Creating bottleneck at ' + bottleneck_path)
+ tf.logging.info('Creating bottleneck at ' + bottleneck_path)
image_path = get_image_path(image_lists, label_name, index,
image_dir, category)
if not gfile.Exists(image_path):
@@ -371,10 +366,11 @@ def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
image_data = gfile.FastGFile(image_path, 'rb').read()
try:
bottleneck_values = run_bottleneck_on_image(
- sess, image_data, jpeg_data_tensor, bottleneck_tensor)
- except:
- raise RuntimeError('Error during processing file %s' % image_path)
-
+ sess, image_data, jpeg_data_tensor, decoded_image_tensor,
+ resized_input_tensor, bottleneck_tensor)
+ except Exception as e:
+ raise RuntimeError('Error during processing file %s (%s)' % (image_path,
+ str(e)))
bottleneck_string = ','.join(str(x) for x in bottleneck_values)
with open(bottleneck_path, 'w') as bottleneck_file:
bottleneck_file.write(bottleneck_string)
@@ -382,7 +378,8 @@ def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
category, bottleneck_dir, jpeg_data_tensor,
- bottleneck_tensor):
+ decoded_image_tensor, resized_input_tensor,
+ bottleneck_tensor, architecture):
"""Retrieves or calculates bottleneck values for an image.
If a cached version of the bottleneck data exists on-disk, return that,
@@ -400,7 +397,10 @@ def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
or validation.
bottleneck_dir: Folder string holding cached files of bottleneck values.
jpeg_data_tensor: The tensor to feed loaded jpeg data into.
+ decoded_image_tensor: The output of decoding and resizing the image.
+ resized_input_tensor: The input node of the recognition graph.
bottleneck_tensor: The output tensor for the bottleneck values.
+ architecture: The name of the model architecture.
Returns:
Numpy array of values produced by the bottleneck layer for the image.
@@ -410,10 +410,11 @@ def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
sub_dir_path = os.path.join(bottleneck_dir, sub_dir)
ensure_dir_exists(sub_dir_path)
bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
- bottleneck_dir, category)
+ bottleneck_dir, category, architecture)
if not os.path.exists(bottleneck_path):
create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
image_dir, category, sess, jpeg_data_tensor,
+ decoded_image_tensor, resized_input_tensor,
bottleneck_tensor)
with open(bottleneck_path, 'r') as bottleneck_file:
bottleneck_string = bottleneck_file.read()
@@ -421,11 +422,12 @@ def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
try:
bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
except ValueError:
- print('Invalid float found, recreating bottleneck')
+ tf.logging.warning('Invalid float found, recreating bottleneck')
did_hit_error = True
if did_hit_error:
create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
image_dir, category, sess, jpeg_data_tensor,
+ decoded_image_tensor, resized_input_tensor,
bottleneck_tensor)
with open(bottleneck_path, 'r') as bottleneck_file:
bottleneck_string = bottleneck_file.read()
@@ -436,7 +438,8 @@ def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
- jpeg_data_tensor, bottleneck_tensor):
+ jpeg_data_tensor, decoded_image_tensor,
+ resized_input_tensor, bottleneck_tensor, architecture):
"""Ensures all the training, testing, and validation bottlenecks are cached.
Because we're likely to read the same image multiple times (if there are no
@@ -453,7 +456,10 @@ def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
images.
bottleneck_dir: Folder string holding cached files of bottleneck values.
jpeg_data_tensor: Input tensor for jpeg data from file.
+ decoded_image_tensor: The output of decoding and resizing the image.
+ resized_input_tensor: The input node of the recognition graph.
bottleneck_tensor: The penultimate output layer of the graph.
+ architecture: The name of the model architecture.
Returns:
Nothing.
@@ -464,18 +470,21 @@ def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
for category in ['training', 'testing', 'validation']:
category_list = label_lists[category]
for index, unused_base_name in enumerate(category_list):
- get_or_create_bottleneck(sess, image_lists, label_name, index,
- image_dir, category, bottleneck_dir,
- jpeg_data_tensor, bottleneck_tensor)
+ get_or_create_bottleneck(
+ sess, image_lists, label_name, index, image_dir, category,
+ bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
+ resized_input_tensor, bottleneck_tensor, architecture)
how_many_bottlenecks += 1
if how_many_bottlenecks % 100 == 0:
- print(str(how_many_bottlenecks) + ' bottleneck files created.')
+ tf.logging.info(
+ str(how_many_bottlenecks) + ' bottleneck files created.')
def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
bottleneck_dir, image_dir, jpeg_data_tensor,
- bottleneck_tensor):
+ decoded_image_tensor, resized_input_tensor,
+ bottleneck_tensor, architecture):
"""Retrieves bottleneck values for cached images.
If no distortions are being applied, this function can retrieve the cached
@@ -493,7 +502,10 @@ def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
image_dir: Root folder string of the subfolders containing the training
images.
jpeg_data_tensor: The layer to feed jpeg image data into.
+ decoded_image_tensor: The output of decoding and resizing the image.
+ resized_input_tensor: The input node of the recognition graph.
bottleneck_tensor: The bottleneck output layer of the CNN graph.
+ architecture: The name of the model architecture.
Returns:
List of bottleneck arrays, their corresponding ground truths, and the
@@ -511,10 +523,10 @@ def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
image_name = get_image_path(image_lists, label_name, image_index,
image_dir, category)
- bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
- image_index, image_dir, category,
- bottleneck_dir, jpeg_data_tensor,
- bottleneck_tensor)
+ bottleneck = get_or_create_bottleneck(
+ sess, image_lists, label_name, image_index, image_dir, category,
+ bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
+ resized_input_tensor, bottleneck_tensor, architecture)
ground_truth = np.zeros(class_count, dtype=np.float32)
ground_truth[label_index] = 1.0
bottlenecks.append(bottleneck)
@@ -527,10 +539,10 @@ def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
image_lists[label_name][category]):
image_name = get_image_path(image_lists, label_name, image_index,
image_dir, category)
- bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
- image_index, image_dir, category,
- bottleneck_dir, jpeg_data_tensor,
- bottleneck_tensor)
+ bottleneck = get_or_create_bottleneck(
+ sess, image_lists, label_name, image_index, image_dir, category,
+ bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
+ resized_input_tensor, bottleneck_tensor, architecture)
ground_truth = np.zeros(class_count, dtype=np.float32)
ground_truth[label_index] = 1.0
bottlenecks.append(bottleneck)
@@ -583,12 +595,12 @@ def get_random_distorted_bottlenecks(
# might be optimized in other implementations.
distorted_image_data = sess.run(distorted_image,
{input_jpeg_tensor: jpeg_data})
- bottleneck = run_bottleneck_on_image(sess, distorted_image_data,
- resized_input_tensor,
- bottleneck_tensor)
+ bottleneck_values = sess.run(bottleneck_tensor,
+ {resized_input_tensor: distorted_image_data})
+ bottleneck_values = np.squeeze(bottleneck_values)
ground_truth = np.zeros(class_count, dtype=np.float32)
ground_truth[label_index] = 1.0
- bottlenecks.append(bottleneck)
+ bottlenecks.append(bottleneck_values)
ground_truths.append(ground_truth)
return bottlenecks, ground_truths
@@ -612,7 +624,8 @@ def should_distort_images(flip_left_right, random_crop, random_scale,
def add_input_distortions(flip_left_right, random_crop, random_scale,
- random_brightness):
+ random_brightness, input_width, input_height,
+ input_depth, input_mean, input_std):
"""Creates the operations to apply the specified distortions.
During training it can help to improve the results if we run the images
@@ -660,13 +673,18 @@ def add_input_distortions(flip_left_right, random_crop, random_scale,
random_scale: Integer percentage of how much to vary the scale by.
random_brightness: Integer range to randomly multiply the pixel values by.
graph.
+ input_width: Horizontal size of expected input image to model.
+ input_height: Vertical size of expected input image to model.
+ input_depth: How many channels the expected input image should have.
+ input_mean: Pixel value that should be zero in the image for the graph.
+ input_std: How much to divide the pixel values by before recognition.
Returns:
The jpeg input layer and the distorted result tensor.
"""
jpeg_data = tf.placeholder(tf.string, name='DistortJPGInput')
- decoded_image = tf.image.decode_jpeg(jpeg_data, channels=MODEL_INPUT_DEPTH)
+ decoded_image = tf.image.decode_jpeg(jpeg_data, channels=input_depth)
decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32)
decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
margin_scale = 1.0 + (random_crop / 100.0)
@@ -676,16 +694,15 @@ def add_input_distortions(flip_left_right, random_crop, random_scale,
minval=1.0,
maxval=resize_scale)
scale_value = tf.multiply(margin_scale_value, resize_scale_value)
- precrop_width = tf.multiply(scale_value, MODEL_INPUT_WIDTH)
- precrop_height = tf.multiply(scale_value, MODEL_INPUT_HEIGHT)
+ precrop_width = tf.multiply(scale_value, input_width)
+ precrop_height = tf.multiply(scale_value, input_height)
precrop_shape = tf.stack([precrop_height, precrop_width])
precrop_shape_as_int = tf.cast(precrop_shape, dtype=tf.int32)
precropped_image = tf.image.resize_bilinear(decoded_image_4d,
precrop_shape_as_int)
precropped_image_3d = tf.squeeze(precropped_image, squeeze_dims=[0])
cropped_image = tf.random_crop(precropped_image_3d,
- [MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH,
- MODEL_INPUT_DEPTH])
+ [input_height, input_width, input_depth])
if flip_left_right:
flipped_image = tf.image.random_flip_left_right(cropped_image)
else:
@@ -696,7 +713,9 @@ def add_input_distortions(flip_left_right, random_crop, random_scale,
minval=brightness_min,
maxval=brightness_max)
brightened_image = tf.multiply(flipped_image, brightness_value)
- distort_result = tf.expand_dims(brightened_image, 0, name='DistortResult')
+ offset_image = tf.subtract(brightened_image, input_mean)
+ mul_image = tf.multiply(offset_image, 1.0 / input_std)
+ distort_result = tf.expand_dims(mul_image, 0, name='DistortResult')
return jpeg_data, distort_result
@@ -713,7 +732,8 @@ def variable_summaries(var):
tf.summary.histogram('histogram', var)
-def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
+def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
+ bottleneck_tensor_size):
"""Adds a new softmax and fully-connected layer for training.
We need to retrain the top layer to identify our new classes, so this function
@@ -728,6 +748,7 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
recognize.
final_tensor_name: Name string for the new final node that produces results.
bottleneck_tensor: The output of the main CNN graph.
+ bottleneck_tensor_size: How many entries in the bottleneck vector.
Returns:
The tensors for the training and cross entropy results, and tensors for the
@@ -735,7 +756,8 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
"""
with tf.name_scope('input'):
bottleneck_input = tf.placeholder_with_default(
- bottleneck_tensor, shape=[None, BOTTLENECK_TENSOR_SIZE],
+ bottleneck_tensor,
+ shape=[None, bottleneck_tensor_size],
name='BottleneckInputPlaceholder')
ground_truth_input = tf.placeholder(tf.float32,
@@ -747,8 +769,8 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
layer_name = 'final_training_ops'
with tf.name_scope(layer_name):
with tf.name_scope('weights'):
- initial_value = tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count],
- stddev=0.001)
+ initial_value = tf.truncated_normal(
+ [bottleneck_tensor_size, class_count], stddev=0.001)
layer_weights = tf.Variable(initial_value, name='final_weights')
@@ -802,7 +824,7 @@ def add_evaluation_step(result_tensor, ground_truth_tensor):
def save_graph_to_file(sess, graph, graph_file_name):
output_graph_def = graph_util.convert_variables_to_constants(
- sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
+ sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
with gfile.FastGFile(graph_file_name, 'wb') as f:
f.write(output_graph_def.SerializeToString())
return
@@ -818,25 +840,160 @@ def prepare_file_system():
return
+def create_model_info(architecture):
+ """Given the name of a model architecture, returns information about it.
+
+ There are different base image recognition pretrained models that can be
+ retrained using transfer learning, and this function translates from the name
+ of a model to the attributes that are needed to download and train with it.
+
+ Args:
+ architecture: Name of a model architecture.
+
+ Returns:
+ Dictionary of information about the model, or None if the name isn't
+ recognized
+
+ Raises:
+ ValueError: If architecture name is unknown.
+ """
+ architecture = architecture.lower()
+ if architecture == 'inception_v3':
+ # pylint: disable=line-too-long
+ data_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
+ # pylint: enable=line-too-long
+ bottleneck_tensor_name = 'pool_3/_reshape:0'
+ bottleneck_tensor_size = 2048
+ input_width = 299
+ input_height = 299
+ input_depth = 3
+ resized_input_tensor_name = 'Mul:0'
+ model_file_name = 'classify_image_graph_def.pb'
+ input_mean = 128
+ input_std = 128
+ elif architecture.startswith('mobilenet_'):
+ parts = architecture.split('_')
+ if len(parts) != 3 and len(parts) != 4:
+ tf.logging.error("Couldn't understand architecture name '%s'",
+ architecture)
+ return None
+ version_string = parts[1]
+ if (version_string != '1.0' and version_string != '0.75' and
+ version_string != '0.50' and version_string != '0.25'):
+ tf.logging.error(
+ """"The Mobilenet version should be '1.0', '0.75', '0.50', or '0.25',
+ but found '%s' for architecture '%s'""",
+ version_string, architecture)
+ return None
+ size_string = parts[2]
+ if (size_string != '224' and size_string != '192' and
+ size_string != '160' and size_string != '128'):
+ tf.logging.error(
+ """The Mobilenet input size should be '224', '192', '160', or '128',
+ but found '%s' for architecture '%s'""",
+ size_string, architecture)
+ return None
+ if len(parts) == 3:
+ is_quantized = False
+ else:
+ if parts[3] != 'quantized':
+ tf.logging.error(
+ "Couldn't understand architecture suffix '%s' for '%s'", parts[3],
+ architecture)
+ return None
+ is_quantized = True
+ data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
+ data_url += version_string + '_' + size_string + '_frozen.tgz'
+ bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
+ bottleneck_tensor_size = 1001
+ input_width = int(size_string)
+ input_height = int(size_string)
+ input_depth = 3
+ resized_input_tensor_name = 'input:0'
+ if is_quantized:
+ model_base_name = 'quantized_graph.pb'
+ else:
+ model_base_name = 'frozen_graph.pb'
+ model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
+ model_file_name = os.path.join(model_dir_name, model_base_name)
+ input_mean = 127.5
+ input_std = 127.5
+ else:
+ tf.logging.error("Couldn't understand architecture name '%s'", architecture)
+ raise ValueError('Unknown architecture', architecture)
+
+ return {
+ 'data_url': data_url,
+ 'bottleneck_tensor_name': bottleneck_tensor_name,
+ 'bottleneck_tensor_size': bottleneck_tensor_size,
+ 'input_width': input_width,
+ 'input_height': input_height,
+ 'input_depth': input_depth,
+ 'resized_input_tensor_name': resized_input_tensor_name,
+ 'model_file_name': model_file_name,
+ 'input_mean': input_mean,
+ 'input_std': input_std,
+ }
+
+
+def add_jpeg_decoding(input_width, input_height, input_depth, input_mean,
+ input_std):
+ """Adds operations that perform JPEG decoding and resizing to the graph..
+
+ Args:
+ input_width: Desired width of the image fed into the recognizer graph.
+ input_height: Desired width of the image fed into the recognizer graph.
+ input_depth: Desired channels of the image fed into the recognizer graph.
+ input_mean: Pixel value that should be zero in the image for the graph.
+ input_std: How much to divide the pixel values by before recognition.
+
+ Returns:
+ Tensors for the node to feed JPEG data into, and the output of the
+ preprocessing steps.
+ """
+ jpeg_data = tf.placeholder(tf.string, name='DecodeJPGInput')
+ decoded_image = tf.image.decode_jpeg(jpeg_data, channels=input_depth)
+ decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32)
+ decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
+ resize_shape = tf.stack([input_height, input_width])
+ resize_shape_as_int = tf.cast(resize_shape, dtype=tf.int32)
+ resized_image = tf.image.resize_bilinear(decoded_image_4d,
+ resize_shape_as_int)
+ offset_image = tf.subtract(resized_image, input_mean)
+ mul_image = tf.multiply(offset_image, 1.0 / input_std)
+ return jpeg_data, mul_image
+
+
def main(_):
+ # Needed to make sure the logging output is visible.
+ # See https://github.com/tensorflow/tensorflow/issues/3047
+ tf.logging.set_verbosity(tf.logging.INFO)
+
# Prepare necessary directories that can be used during training
prepare_file_system()
+ # Gather information about the model architecture we'll be using.
+ model_info = create_model_info(FLAGS.architecture)
+ if not model_info:
+ tf.logging.error('Did not recognize architecture flag')
+ return -1
+
# Set up the pre-trained graph.
- maybe_download_and_extract()
- graph, bottleneck_tensor, jpeg_data_tensor, resized_image_tensor = (
- create_inception_graph())
+ maybe_download_and_extract(model_info['data_url'])
+ graph, bottleneck_tensor, resized_image_tensor = (
+ create_model_graph(model_info))
# Look at the folder structure, and create lists of all the images.
image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage,
FLAGS.validation_percentage)
class_count = len(image_lists.keys())
if class_count == 0:
- print('No valid folders of images found at ' + FLAGS.image_dir)
+ tf.logging.error('No valid folders of images found at ' + FLAGS.image_dir)
return -1
if class_count == 1:
- print('Only one valid folder of images found at ' + FLAGS.image_dir +
- ' - multiple classes are needed for classification.')
+ tf.logging.error('Only one valid folder of images found at ' +
+ FLAGS.image_dir +
+ ' - multiple classes are needed for classification.')
return -1
# See if the command-line flags mean we're applying any distortions.
@@ -845,25 +1002,33 @@ def main(_):
FLAGS.random_brightness)
with tf.Session(graph=graph) as sess:
+ # Set up the image decoding sub-graph.
+ jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(
+ model_info['input_width'], model_info['input_height'],
+ model_info['input_depth'], model_info['input_mean'],
+ model_info['input_std'])
if do_distort_images:
# We will be applying distortions, so setup the operations we'll need.
(distorted_jpeg_data_tensor,
distorted_image_tensor) = add_input_distortions(
- FLAGS.flip_left_right, FLAGS.random_crop,
- FLAGS.random_scale, FLAGS.random_brightness)
+ FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
+ FLAGS.random_brightness, model_info['input_width'],
+ model_info['input_height'], model_info['input_depth'],
+ model_info['input_mean'], model_info['input_std'])
else:
# We'll make sure we've calculated the 'bottleneck' image summaries and
# cached them on disk.
cache_bottlenecks(sess, image_lists, FLAGS.image_dir,
FLAGS.bottleneck_dir, jpeg_data_tensor,
- bottleneck_tensor)
+ decoded_image_tensor, resized_image_tensor,
+ bottleneck_tensor, FLAGS.architecture)
# Add the new layer that we'll be training.
(train_step, cross_entropy, bottleneck_input, ground_truth_input,
- final_tensor) = add_final_training_ops(len(image_lists.keys()),
- FLAGS.final_tensor_name,
- bottleneck_tensor)
+ final_tensor) = add_final_training_ops(
+ len(image_lists.keys()), FLAGS.final_tensor_name, bottleneck_tensor,
+ model_info['bottleneck_tensor_size'])
# Create the operations we need to evaluate the accuracy of our new layer.
evaluation_step, prediction = add_evaluation_step(
@@ -896,10 +1061,10 @@ def main(_):
train_ground_truth, _) = get_random_cached_bottlenecks(
sess, image_lists, FLAGS.train_batch_size, 'training',
FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
- bottleneck_tensor)
+ decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
+ FLAGS.architecture)
# Feed the bottlenecks and ground truth into the graph, and run a training
# step. Capture training summaries for TensorBoard with the `merged` op.
-
train_summary, _ = sess.run(
[merged, train_step],
feed_dict={bottleneck_input: train_bottlenecks,
@@ -913,15 +1078,16 @@ def main(_):
[evaluation_step, cross_entropy],
feed_dict={bottleneck_input: train_bottlenecks,
ground_truth_input: train_ground_truth})
- print('%s: Step %d: Train accuracy = %.1f%%' % (datetime.now(), i,
- train_accuracy * 100))
- print('%s: Step %d: Cross entropy = %f' % (datetime.now(), i,
- cross_entropy_value))
+ tf.logging.info('%s: Step %d: Train accuracy = %.1f%%' %
+ (datetime.now(), i, train_accuracy * 100))
+ tf.logging.info('%s: Step %d: Cross entropy = %f' %
+ (datetime.now(), i, cross_entropy_value))
validation_bottlenecks, validation_ground_truth, _ = (
get_random_cached_bottlenecks(
sess, image_lists, FLAGS.validation_batch_size, 'validation',
FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
- bottleneck_tensor))
+ decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
+ FLAGS.architecture))
# Run a validation step and capture training summaries for TensorBoard
# with the `merged` op.
validation_summary, validation_accuracy = sess.run(
@@ -929,38 +1095,43 @@ def main(_):
feed_dict={bottleneck_input: validation_bottlenecks,
ground_truth_input: validation_ground_truth})
validation_writer.add_summary(validation_summary, i)
- print('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %
- (datetime.now(), i, validation_accuracy * 100,
- len(validation_bottlenecks)))
+ tf.logging.info('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %
+ (datetime.now(), i, validation_accuracy * 100,
+ len(validation_bottlenecks)))
# Store intermediate results
intermediate_frequency = FLAGS.intermediate_store_frequency
- if intermediate_frequency > 0 and (i % intermediate_frequency == 0) and i > 0:
- intermediate_file_name = FLAGS.intermediate_output_graphs_dir + 'intermediate_' + str(i) + '.pb'
- print('Save intermediate result to : ' + intermediate_file_name)
+ if (intermediate_frequency > 0 and (i % intermediate_frequency == 0)
+ and i > 0):
+ intermediate_file_name = (FLAGS.intermediate_output_graphs_dir +
+ 'intermediate_' + str(i) + '.pb')
+ tf.logging.info('Save intermediate result to : ' +
+ intermediate_file_name)
save_graph_to_file(sess, graph, intermediate_file_name)
-
+
# We've completed all our training, so run a final test evaluation on
# some new images we haven't used before.
test_bottlenecks, test_ground_truth, test_filenames = (
- get_random_cached_bottlenecks(sess, image_lists, FLAGS.test_batch_size,
- 'testing', FLAGS.bottleneck_dir,
- FLAGS.image_dir, jpeg_data_tensor,
- bottleneck_tensor))
+ get_random_cached_bottlenecks(
+ sess, image_lists, FLAGS.test_batch_size, 'testing',
+ FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
+ decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
+ FLAGS.architecture))
test_accuracy, predictions = sess.run(
[evaluation_step, prediction],
feed_dict={bottleneck_input: test_bottlenecks,
ground_truth_input: test_ground_truth})
- print('Final test accuracy = %.1f%% (N=%d)' % (
- test_accuracy * 100, len(test_bottlenecks)))
+ tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
+ (test_accuracy * 100, len(test_bottlenecks)))
if FLAGS.print_misclassified_test_images:
- print('=== MISCLASSIFIED TEST IMAGES ===')
+ tf.logging.info('=== MISCLASSIFIED TEST IMAGES ===')
for i, test_filename in enumerate(test_filenames):
if predictions[i] != test_ground_truth[i].argmax():
- print('%70s %s' % (test_filename,
- list(image_lists.keys())[predictions[i]]))
+ tf.logging.info('%70s %s' %
+ (test_filename,
+ list(image_lists.keys())[predictions[i]]))
# Write out the trained graph and labels with the weights stored as
# constants.
@@ -993,7 +1164,10 @@ if __name__ == '__main__':
'--intermediate_store_frequency',
type=int,
default=0,
- help='How many steps to store intermediate graph. If "0" then will not store.'
+ help="""\
+ How many steps to store intermediate graph. If "0" then will not
+ store.\
+ """
)
parser.add_argument(
'--output_labels',
@@ -1134,5 +1308,19 @@ if __name__ == '__main__':
input pixels up or down by.\
"""
)
+ parser.add_argument(
+ '--architecture',
+ type=str,
+ default='inception_v3',
+ help="""\
+ Which model architecture to use. 'inception_v3' is the most accurate, but
+ also the slowest. For faster or smaller models, chose a MobileNet with the
+ form 'mobilenet_<parameter size>_<input_size>[_quantized]'. For example,
+ 'mobilenet_1.0_224' will pick a model that is 17 MB in size and takes 224
+ pixel input images, while 'mobilenet_0.25_128_quantized' will choose a much
+ less accurate, but smaller and faster network that's 920 KB on disk and
+ takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
+ for more information on Mobilenet.\
+ """)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py
index 8af5cc7114..467c15d0de 100644
--- a/tensorflow/examples/image_retraining/retrain_test.py
+++ b/tensorflow/examples/image_retraining/retrain_test.py
@@ -48,10 +48,10 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
def testGetBottleneckPath(self):
image_lists = self.dummyImageLists()
- self.assertEqual('bottleneck_dir/somedir/image_five.jpg.txt',
+ self.assertEqual('bottleneck_dir/somedir/image_five.jpg_imagenet_v3.txt',
retrain.get_bottleneck_path(
image_lists, 'label_one', 0, 'bottleneck_dir',
- 'validation'))
+ 'validation', 'imagenet_v3'))
def testShouldDistortImage(self):
self.assertEqual(False, retrain.should_distort_images(False, 0, 0, 0))
@@ -63,7 +63,7 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
def testAddInputDistortions(self):
with tf.Graph().as_default():
with tf.Session() as sess:
- retrain.add_input_distortions(True, 10, 10, 10)
+ retrain.add_input_distortions(True, 10, 10, 10, 299, 299, 3, 128, 128)
self.assertIsNotNone(sess.graph.get_tensor_by_name('DistortJPGInput:0'))
self.assertIsNotNone(sess.graph.get_tensor_by_name('DistortResult:0'))
@@ -72,9 +72,9 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
with tf.Graph().as_default():
with tf.Session() as sess:
bottleneck = tf.placeholder(
- tf.float32, [1, retrain.BOTTLENECK_TENSOR_SIZE],
- name=retrain.BOTTLENECK_TENSOR_NAME.split(':')[0])
- retrain.add_final_training_ops(5, 'final', bottleneck)
+ tf.float32, [1, 1024],
+ name='bottleneck')
+ retrain.add_final_training_ops(5, 'final', bottleneck, 1024)
self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0'))
def testAddEvaluationStep(self):
@@ -113,5 +113,22 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
result = label_image.run_graph(image, labels, jpeg, 'final:0', 3)
self.assertEqual(result, 0)
+ def testAddJpegDecoding(self):
+ with tf.Graph().as_default():
+ jpeg_data, mul_image = retrain.add_jpeg_decoding(10, 10, 3, 0, 255)
+ self.assertIsNotNone(jpeg_data)
+ self.assertIsNotNone(mul_image)
+
+ def testCreateModelInfo(self):
+ did_raise_value_error = False
+ try:
+ retrain.create_model_info('no_such_model_name')
+ except ValueError:
+ did_raise_value_error = True
+ self.assertTrue(did_raise_value_error)
+ model_info = retrain.create_model_info('inception_v3')
+ self.assertIsNotNone(model_info)
+ self.assertEqual(299, model_info['input_width'])
+
if __name__ == '__main__':
tf.test.main()