about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/examples/image_retraining
diff options
context:
space:
mode:
author: Suharsh Sivakumar <suharshs@google.com> 2017-11-07 18:38:12 -0800
committer: Andrew Selle <aselle@andyselle.com> 2017-11-10 16:14:34 -0800
commit: 54837e40096c35322e75d43a13bbf44c933f59db (patch)
tree: 2009c94347ac8149c8feac9f10831e77a5214842 /tensorflow/examples/image_retraining
parent: 6bc5375cb07d8d595411ec0516d29314053a8e83 (diff)
Add functionality to train an additional fixed-point layer on top of a quantized base model.
Also modify retrain_test to test creation of model info for the fixed-point mobilenet.

PiperOrigin-RevId: 174946745
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r-- tensorflow/examples/image_retraining/retrain.py | 82
-rw-r--r-- tensorflow/examples/image_retraining/retrain_test.py | 23
2 files changed, 85 insertions, 20 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 3549891461..ebddfb20f4 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -69,11 +69,18 @@ to validate that you have gathered good training data, but if you want to deploy
on resource-limited platforms, you can try the `--architecture` flag with a
Mobilenet model. For example:
+Run floating-point version of mobilenet:
```bash
python tensorflow/examples/image_retraining/retrain.py \
--image_dir ~/flower_photos --architecture mobilenet_1.0_224
```
+Run quantized version of mobilenet:
+```bash
+python tensorflow/examples/image_retraining/retrain.py \
+ --image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quantized
+```
+
There are 32 different Mobilenet models to choose from, with a variety of file
size and latency options. The first number can be '1.0', '0.75', '0.50', or
'0.25' to control the size, and the second controls the input image size, either
@@ -107,6 +114,7 @@ import numpy as np
from six.moves import urllib
import tensorflow as tf
+from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile
@@ -271,6 +279,7 @@ def create_model_graph(model_info):
"""
with tf.Graph().as_default() as graph:
model_path = os.path.join(FLAGS.model_dir, model_info['model_file_name'])
+ print('Model path: ', model_path)
with gfile.FastGFile(model_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
@@ -337,7 +346,10 @@ def maybe_download_and_extract(data_url):
statinfo = os.stat(filepath)
tf.logging.info('Successfully downloaded', filename, statinfo.st_size,
'bytes.')
- tarfile.open(filepath, 'r:gz').extractall(dest_directory)
+ print('Extracting file from ', filepath)
+ tarfile.open(filepath, 'r:gz').extractall(dest_directory)
+ else:
+ print('Not extracting or downloading files, model already present in disk')
def ensure_dir_exists(dir_name):
@@ -733,7 +745,7 @@ def variable_summaries(var):
def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
- bottleneck_tensor_size):
+ bottleneck_tensor_size, quantize_layer):
"""Adds a new softmax and fully-connected layer for training.
We need to retrain the top layer to identify our new classes, so this function
@@ -745,10 +757,12 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
Args:
class_count: Integer of how many categories of things we're trying to
- recognize.
+ recognize.
final_tensor_name: Name string for the new final node that produces results.
bottleneck_tensor: The output of the main CNN graph.
bottleneck_tensor_size: How many entries in the bottleneck vector.
+ quantize_layer: Boolean, specifying whether the newly added layer should be
+ quantized.
Returns:
The tensors for the training and cross entropy results, and tensors for the
@@ -771,18 +785,41 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
with tf.name_scope('weights'):
initial_value = tf.truncated_normal(
[bottleneck_tensor_size, class_count], stddev=0.001)
-
layer_weights = tf.Variable(initial_value, name='final_weights')
+ if quantize_layer:
+ quantized_layer_weights = quant_ops.MovingAvgQuantize(
+ layer_weights, is_training=True)
+ variable_summaries(quantized_layer_weights)
variable_summaries(layer_weights)
with tf.name_scope('biases'):
layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
+ if quantize_layer:
+ quantized_layer_biases = quant_ops.MovingAvgQuantize(
+ layer_biases, is_training=True)
+ variable_summaries(quantized_layer_biases)
+
variable_summaries(layer_biases)
+
with tf.name_scope('Wx_plus_b'):
- logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
- tf.summary.histogram('pre_activations', logits)
+ if quantize_layer:
+ logits = tf.matmul(bottleneck_input,
+ quantized_layer_weights) + quantized_layer_biases
+ logits = quant_ops.MovingAvgQuantize(
+ logits,
+ init_min=-32.0,
+ init_max=32.0,
+ is_training=True,
+ num_bits=8,
+ narrow_range=False,
+ ema_decay=0.5)
+ tf.summary.histogram('pre_activations', logits)
+ else:
+ logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
+ tf.summary.histogram('pre_activations', logits)
final_tensor = tf.nn.softmax(logits, name=final_tensor_name)
+
tf.summary.histogram('activations', final_tensor)
with tf.name_scope('cross_entropy'):
@@ -790,6 +827,7 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
labels=ground_truth_input, logits=logits)
with tf.name_scope('total'):
cross_entropy_mean = tf.reduce_mean(cross_entropy)
+
tf.summary.scalar('cross_entropy', cross_entropy_mean)
with tf.name_scope('train'):
@@ -825,6 +863,7 @@ def add_evaluation_step(result_tensor, ground_truth_tensor):
def save_graph_to_file(sess, graph, graph_file_name):
output_graph_def = graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
+
with gfile.FastGFile(graph_file_name, 'wb') as f:
f.write(output_graph_def.SerializeToString())
return
@@ -858,6 +897,7 @@ def create_model_info(architecture):
ValueError: If architecture name is unknown.
"""
architecture = architecture.lower()
+ is_quantized = False
if architecture == 'inception_v3':
# pylint: disable=line-too-long
data_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
@@ -902,19 +942,28 @@ def create_model_info(architecture):
architecture)
return None
is_quantized = True
- data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
- data_url += version_string + '_' + size_string + '_frozen.tgz'
- bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
+
+ if is_quantized:
+ data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
+ data_url += version_string + '_' + size_string + '_quantized_frozen.tgz'
+ bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
+ resized_input_tensor_name = 'Placeholder:0'
+ model_dir_name = ('mobilenet_v1_' + version_string + '_' + size_string +
+ '_quantized_frozen')
+ model_base_name = 'quantized_frozen_graph.pb'
+
+ else:
+ data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
+ data_url += version_string + '_' + size_string + '_frozen.tgz'
+ bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
+ resized_input_tensor_name = 'input:0'
+ model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
+ model_base_name = 'frozen_graph.pb'
+
bottleneck_tensor_size = 1001
input_width = int(size_string)
input_height = int(size_string)
input_depth = 3
- resized_input_tensor_name = 'input:0'
- if is_quantized:
- model_base_name = 'quantized_graph.pb'
- else:
- model_base_name = 'frozen_graph.pb'
- model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
model_file_name = os.path.join(model_dir_name, model_base_name)
input_mean = 127.5
input_std = 127.5
@@ -933,6 +982,7 @@ def create_model_info(architecture):
'model_file_name': model_file_name,
'input_mean': input_mean,
'input_std': input_std,
+ 'quantize_layer': is_quantized,
}
@@ -1028,7 +1078,7 @@ def main(_):
(train_step, cross_entropy, bottleneck_input, ground_truth_input,
final_tensor) = add_final_training_ops(
len(image_lists.keys()), FLAGS.final_tensor_name, bottleneck_tensor,
- model_info['bottleneck_tensor_size'])
+ model_info['bottleneck_tensor_size'], model_info['quantize_layer'])
# Create the operations we need to evaluate the accuracy of our new layer.
evaluation_step, prediction = add_evaluation_step(
diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py
index c342a17dd8..2de4c4ec99 100644
--- a/tensorflow/examples/image_retraining/retrain_test.py
+++ b/tensorflow/examples/image_retraining/retrain_test.py
@@ -70,10 +70,18 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
def testAddFinalTrainingOps(self, flags_mock):
with tf.Graph().as_default():
with tf.Session() as sess:
- bottleneck = tf.placeholder(
- tf.float32, [1, 1024],
- name='bottleneck')
- retrain.add_final_training_ops(5, 'final', bottleneck, 1024)
+ bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck')
+ # Test creating final training op without quantization
+ retrain.add_final_training_ops(5, 'final', bottleneck, 1024, False)
+ self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0'))
+
+ @tf.test.mock.patch.object(retrain, 'FLAGS', learning_rate=0.01)
+ def testAddFinalTrainingOpsQuantized(self, flags_mock):
+ with tf.Graph().as_default():
+ with tf.Session() as sess:
+ bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck')
+ # Test creating final training op with quantization
+ retrain.add_final_training_ops(5, 'final', bottleneck, 1024, True)
self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0'))
def testAddEvaluationStep(self):
@@ -99,5 +107,12 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
self.assertIsNotNone(model_info)
self.assertEqual(299, model_info['input_width'])
+ def testCreateModelInfoQuantized(self):
+ # Test for mobilenet_quantized
+ model_info = retrain.create_model_info('mobilenet_1.0_224')
+ self.assertIsNotNone(model_info)
+ self.assertEqual(224, model_info['input_width'])
+
+
if __name__ == '__main__':
tf.test.main()