aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-02-26 18:04:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-26 18:08:39 -0800
commit19f18e377d8ee2f624406527b21444128da344df (patch)
tree7ec8747362efae370d9b792d0d2fde1b2129165e /tensorflow/examples
parent4aa3d3ce252a9af2e09cdbd5460262ccb5378a3a (diff)
Modify retrain script to output TFLite compatible quantized models.
-Also fix flaky input name selection introduced by last PR. -Also rely on tf.contrib.quantize to do graph transformations. -Also, update retrain script to use new float mobilenet_v1 and quantized mobilenet_v1 models. PiperOrigin-RevId: 187111533
Diffstat (limited to 'tensorflow/examples')
-rw-r--r--tensorflow/examples/image_retraining/retrain.py317
-rw-r--r--tensorflow/examples/image_retraining/retrain_test.py44
2 files changed, 229 insertions, 132 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 25e09fecbf..99a71206ac 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -75,13 +75,16 @@ python tensorflow/examples/image_retraining/retrain.py \
--image_dir ~/flower_photos --architecture mobilenet_1.0_224
```
-Run quantized version of mobilenet:
+Run mobilenet, instrumented for quantization:
```bash
python tensorflow/examples/image_retraining/retrain.py \
- --image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quantized
+ --image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quant
```
+These instrumented models can be converted to fully quantized mobile models via
+TensorFlow Lite.
+
There are 32 different Mobilenet models to choose from, with a variety of file
size and latency options. The first number can be '1.0', '0.75', '0.50', or
'0.25' to control the size, and the second controls the input image size, either
@@ -121,7 +124,6 @@ import numpy as np
from six.moves import urllib
import tensorflow as tf
-from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile
@@ -135,6 +137,9 @@ FLAGS = None
# need to update these to reflect the values in the network you're using.
MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1 # ~134M
+# The location where variable checkpoints will be stored.
+CHECKPOINT_NAME = '/tmp/_retrain_checkpoint'
+
def create_image_lists(image_dir, testing_percentage, validation_percentage):
"""Builds a list of training images from the file system.
@@ -745,9 +750,9 @@ def variable_summaries(var):
tf.summary.histogram('histogram', var)
-def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
- bottleneck_tensor_size, quantize_layer):
- """Adds a new softmax and fully-connected layer for training.
+def add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor,
+ bottleneck_tensor_size, quantize_layer, is_training):
+ """Adds a new softmax and fully-connected layer for training and eval.
We need to retrain the top layer to identify our new classes, so this function
adds the right operations to the graph, along with some variables to hold the
@@ -763,7 +768,9 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
bottleneck_tensor: The output of the main CNN graph.
bottleneck_tensor_size: How many entries in the bottleneck vector.
quantize_layer: Boolean, specifying whether the newly added layer should be
- quantized.
+ instrumented for quantized.
+ is_training: Boolean, specifying whether the newly add layer is for training
+ or eval.
Returns:
The tensors for the training and cross entropy results, and tensors for the
@@ -778,50 +785,41 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
ground_truth_input = tf.placeholder(
tf.int64, [None], name='GroundTruthInput')
- # Organizing the following ops as `final_training_ops` so they're easier
- # to see in TensorBoard
- layer_name = 'final_training_ops'
+ # Organizing the following ops so they are easier to see in TensorBoard.
+ layer_name = 'final_retrain_ops'
with tf.name_scope(layer_name):
with tf.name_scope('weights'):
initial_value = tf.truncated_normal(
[bottleneck_tensor_size, class_count], stddev=0.001)
layer_weights = tf.Variable(initial_value, name='final_weights')
- if quantize_layer:
- quantized_layer_weights = quant_ops.MovingAvgQuantize(
- layer_weights, is_training=True)
- variable_summaries(quantized_layer_weights)
-
variable_summaries(layer_weights)
+
with tf.name_scope('biases'):
layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
- if quantize_layer:
- quantized_layer_biases = quant_ops.MovingAvgQuantize(
- layer_biases, is_training=True)
- variable_summaries(quantized_layer_biases)
-
variable_summaries(layer_biases)
with tf.name_scope('Wx_plus_b'):
- if quantize_layer:
- logits = tf.matmul(bottleneck_input,
- quantized_layer_weights) + quantized_layer_biases
- logits = quant_ops.MovingAvgQuantize(
- logits,
- init_min=-32.0,
- init_max=32.0,
- is_training=True,
- num_bits=8,
- narrow_range=False,
- ema_decay=0.5)
- tf.summary.histogram('pre_activations', logits)
- else:
- logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
- tf.summary.histogram('pre_activations', logits)
+ logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
+ tf.summary.histogram('pre_activations', logits)
final_tensor = tf.nn.softmax(logits, name=final_tensor_name)
+ # The tf.contrib.quantize functions rewrite the graph in place for
+ # quantization. The imported model graph has already been rewritten, so upon
+ # calling these rewrites, only the newly added final layer will be
+ # transformed.
+ if quantize_layer:
+ if is_training:
+ tf.contrib.quantize.create_training_graph()
+ else:
+ tf.contrib.quantize.create_eval_graph()
+
tf.summary.histogram('activations', final_tensor)
+ # If this is an eval graph, we don't need to add loss ops or an optimizer.
+ if not is_training:
+ return None, None, bottleneck_input, ground_truth_input, final_tensor
+
with tf.name_scope('cross_entropy'):
cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
labels=ground_truth_input, logits=logits)
@@ -857,13 +855,91 @@ def add_evaluation_step(result_tensor, ground_truth_tensor):
return evaluation_step, prediction
-def save_graph_to_file(sess, graph, graph_file_name):
+def run_final_eval(sess, model_info, class_count, image_lists, jpeg_data_tensor,
+ decoded_image_tensor, resized_image_tensor,
+ bottleneck_tensor):
+ """Runs a final evaluation on an eval graph using the test data set.
+
+ Args:
+ sess: Session for the train graph.
+ model_info: Model info dictionary from create_model_info()
+ class_count: Number of classes
+ image_lists: Dictionary of training images for each label.
+ jpeg_data_tensor: The layer to feed jpeg image data into.
+ decoded_image_tensor: The output of decoding and resizing the image.
+ resized_image_tensor: The input node of the recognition graph.
+ bottleneck_tensor: The bottleneck output layer of the CNN graph.
+ """
+ (sess, bottleneck_input, ground_truth_input, evaluation_step,
+ prediction) = build_eval_session(model_info, class_count)
+
+ test_bottlenecks, test_ground_truth, test_filenames = (
+ get_random_cached_bottlenecks(sess, image_lists, FLAGS.test_batch_size,
+ 'testing', FLAGS.bottleneck_dir,
+ FLAGS.image_dir, jpeg_data_tensor,
+ decoded_image_tensor, resized_image_tensor,
+ bottleneck_tensor, FLAGS.architecture))
+ test_accuracy, predictions = sess.run(
+ [evaluation_step, prediction],
+ feed_dict={
+ bottleneck_input: test_bottlenecks,
+ ground_truth_input: test_ground_truth
+ })
+ tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
+ (test_accuracy * 100, len(test_bottlenecks)))
+
+ if FLAGS.print_misclassified_test_images:
+ tf.logging.info('=== MISCLASSIFIED TEST IMAGES ===')
+ for i, test_filename in enumerate(test_filenames):
+ if predictions[i] != test_ground_truth[i]:
+ tf.logging.info('%70s %s' % (test_filename,
+ list(image_lists.keys())[predictions[i]]))
+
+
+def build_eval_session(model_info, class_count):
+ """Builds an restored eval session without train operations for exporting.
+
+ Args:
+ model_info: Model info dictionary from create_model_info()
+ class_count: Number of classes
+
+ Returns:
+ Eval session containing the restored eval graph.
+ The bottleneck input, ground truth, eval step, and prediction tensors.
+ """
+ # If quantized, we need to create the correct eval graph for exporting.
+ eval_graph, bottleneck_tensor, _ = create_model_graph(model_info)
+
+ eval_sess = tf.Session(graph=eval_graph)
+ with eval_graph.as_default():
+ # Add the new layer for exporting.
+ (_, _, bottleneck_input,
+ ground_truth_input, final_tensor) = add_final_retrain_ops(
+ class_count, FLAGS.final_tensor_name, bottleneck_tensor,
+ model_info['bottleneck_tensor_size'], model_info['quantize_layer'],
+ False)
+
+ # Now we need to restore the values from the training graph to the eval
+ # graph.
+ tf.train.Saver().restore(eval_sess, CHECKPOINT_NAME)
+
+ evaluation_step, prediction = add_evaluation_step(final_tensor,
+ ground_truth_input)
+
+ return (eval_sess, bottleneck_input, ground_truth_input, evaluation_step,
+ prediction)
+
+
+def save_graph_to_file(graph, graph_file_name, model_info, class_count):
+ """Saves an graph to file, creating a valid quantized one if necessary."""
+ sess, _, _, _, _ = build_eval_session(model_info, class_count)
+ graph = sess.graph
+
output_graph_def = graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
with gfile.FastGFile(graph_file_name, 'wb') as f:
f.write(output_graph_def.SerializeToString())
- return
def prepare_file_system():
@@ -916,11 +992,10 @@ def create_model_info(architecture):
return None
version_string = parts[1]
if (version_string != '1.0' and version_string != '0.75' and
- version_string != '0.50' and version_string != '0.25'):
+ version_string != '0.5' and version_string != '0.25'):
tf.logging.error(
- """"The Mobilenet version should be '1.0', '0.75', '0.50', or '0.25',
- but found '%s' for architecture '%s'""",
- version_string, architecture)
+ """"The Mobilenet version should be '1.0', '0.75', '0.5', or '0.25',
+ but found '%s' for architecture '%s'""", version_string, architecture)
return None
size_string = parts[2]
if (size_string != '224' and size_string != '192' and
@@ -933,35 +1008,26 @@ def create_model_info(architecture):
if len(parts) == 3:
is_quantized = False
else:
- if parts[3] != 'quantized':
+ if parts[3] != 'quant':
tf.logging.error(
"Couldn't understand architecture suffix '%s' for '%s'", parts[3],
architecture)
return None
is_quantized = True
+ data_url = 'http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/'
+ model_name = 'mobilenet_v1_' + version_string + '_' + size_string
if is_quantized:
- data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
- data_url += version_string + '_' + size_string + '_quantized_frozen.tgz'
- bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
- resized_input_tensor_name = 'Placeholder:0'
- model_dir_name = ('mobilenet_v1_' + version_string + '_' + size_string +
- '_quantized_frozen')
- model_base_name = 'quantized_frozen_graph.pb'
-
- else:
- data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
- data_url += version_string + '_' + size_string + '_frozen.tgz'
- bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
- resized_input_tensor_name = 'input:0'
- model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
- model_base_name = 'frozen_graph.pb'
+ model_name += '_quant'
+ data_url += model_name + '.tgz'
+ bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
+ resized_input_tensor_name = 'input:0'
+ model_file_name = model_name + '_frozen.pb'
bottleneck_tensor_size = 1001
input_width = int(size_string)
input_height = int(size_string)
input_depth = 3
- model_file_name = os.path.join(model_dir_name, model_base_name)
input_mean = 127.5
input_std = 127.5
else:
@@ -1011,43 +1077,45 @@ def add_jpeg_decoding(input_width, input_height, input_depth, input_mean,
return jpeg_data, mul_image
-def export_model(sess, architecture, saved_model_dir):
+def export_model(model_info, class_count, saved_model_dir):
"""Exports model for serving.
Args:
- sess: Current active TensorFlow Session.
- architecture: Model architecture.
+ model_info: The modelinfo for the current model.
+ class_count: The number of classes.
saved_model_dir: Directory in which to save exported model and variables.
"""
- if architecture == 'inception_v3':
- input_tensor = 'DecodeJpeg/contents:0'
- elif architecture.startswith('mobilenet_'):
- input_tensor = 'input:0'
- else:
- raise ValueError('Unknown architecture', architecture)
- in_image = sess.graph.get_tensor_by_name(input_tensor)
- inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)}
-
- out_classes = sess.graph.get_tensor_by_name('final_result:0')
- outputs = {'prediction': tf.saved_model.utils.build_tensor_info(out_classes)}
+ # The SavedModel should hold the eval graph.
+ sess, _, _, _, _ = build_eval_session(model_info, class_count)
+ graph = sess.graph
+ with graph.as_default():
+ input_tensor = model_info['resized_input_tensor_name']
+ in_image = sess.graph.get_tensor_by_name(input_tensor)
+ inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)}
+
+ out_classes = sess.graph.get_tensor_by_name('final_result:0')
+ outputs = {
+ 'prediction': tf.saved_model.utils.build_tensor_info(out_classes)
+ }
- signature = tf.saved_model.signature_def_utils.build_signature_def(
- inputs=inputs,
- outputs=outputs,
- method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
+ signature = tf.saved_model.signature_def_utils.build_signature_def(
+ inputs=inputs,
+ outputs=outputs,
+ method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
- legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
+ legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
- # Save out the SavedModel.
- builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
- builder.add_meta_graph_and_variables(
- sess, [tf.saved_model.tag_constants.SERVING],
- signature_def_map={
- tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
- signature
- },
- legacy_init_op=legacy_init_op)
- builder.save()
+ # Save out the SavedModel.
+ builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
+ builder.add_meta_graph_and_variables(
+ sess, [tf.saved_model.tag_constants.SERVING],
+ signature_def_map={
+ tf.saved_model.signature_constants.
+ DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ signature
+ },
+ legacy_init_op=legacy_init_op)
+ builder.save()
def main(_):
@@ -1064,11 +1132,6 @@ def main(_):
tf.logging.error('Did not recognize architecture flag')
return -1
- # Set up the pre-trained graph.
- maybe_download_and_extract(model_info['data_url'])
- graph, bottleneck_tensor, resized_image_tensor = (
- create_model_graph(model_info))
-
# Look at the folder structure, and create lists of all the images.
image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage,
FLAGS.validation_percentage)
@@ -1087,6 +1150,19 @@ def main(_):
FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
FLAGS.random_brightness)
+ # Set up the pre-trained graph.
+ maybe_download_and_extract(model_info['data_url'])
+ graph, bottleneck_tensor, resized_image_tensor = (
+ create_model_graph(model_info))
+
+ # Add the new layer that we'll be training.
+ with graph.as_default():
+ (train_step, cross_entropy, bottleneck_input,
+ ground_truth_input, final_tensor) = add_final_retrain_ops(
+ class_count, FLAGS.final_tensor_name, bottleneck_tensor,
+ model_info['bottleneck_tensor_size'], model_info['quantize_layer'],
+ True)
+
with tf.Session(graph=graph) as sess:
# Set up the image decoding sub-graph.
jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(
@@ -1110,15 +1186,8 @@ def main(_):
decoded_image_tensor, resized_image_tensor,
bottleneck_tensor, FLAGS.architecture)
- # Add the new layer that we'll be training.
- (train_step, cross_entropy, bottleneck_input, ground_truth_input,
- final_tensor) = add_final_training_ops(
- len(image_lists.keys()), FLAGS.final_tensor_name, bottleneck_tensor,
- model_info['bottleneck_tensor_size'], model_info['quantize_layer'])
-
# Create the operations we need to evaluate the accuracy of our new layer.
- evaluation_step, prediction = add_evaluation_step(
- final_tensor, ground_truth_input)
+ evaluation_step, _ = add_evaluation_step(final_tensor, ground_truth_input)
# Merge all the summaries and write them out to the summaries_dir
merged = tf.summary.merge_all()
@@ -1128,6 +1197,10 @@ def main(_):
validation_writer = tf.summary.FileWriter(
FLAGS.summaries_dir + '/validation')
+ # Create a train saver that is used to restore values into an eval graph
+ # when exporting models.
+ train_saver = tf.train.Saver()
+
# Set up all our weights to their initial default values.
init = tf.global_variables_initializer()
sess.run(init)
@@ -1168,6 +1241,9 @@ def main(_):
(datetime.now(), i, train_accuracy * 100))
tf.logging.info('%s: Step %d: Cross entropy = %f' %
(datetime.now(), i, cross_entropy_value))
+ # TODO(suharshs): Make this use an eval graph, to avoid quantization
+ # moving averages being updated by the validation set, though in
+ # practice this makes a negligable difference.
validation_bottlenecks, validation_ground_truth, _ = (
get_random_cached_bottlenecks(
sess, image_lists, FLAGS.validation_batch_size, 'validation',
@@ -1190,42 +1266,32 @@ def main(_):
if (intermediate_frequency > 0 and (i % intermediate_frequency == 0)
and i > 0):
+ # If we want to do an intermediate save, save a checkpoint of the train
+ # graph, to restore into the eval graph.
+ train_saver.save(sess, CHECKPOINT_NAME)
intermediate_file_name = (FLAGS.intermediate_output_graphs_dir +
'intermediate_' + str(i) + '.pb')
tf.logging.info('Save intermediate result to : ' +
intermediate_file_name)
- save_graph_to_file(sess, graph, intermediate_file_name)
+ save_graph_to_file(graph, intermediate_file_name, model_info,
+ class_count)
+
+ # After training is complete, force one last save of the train checkpoint.
+ train_saver.save(sess, CHECKPOINT_NAME)
# We've completed all our training, so run a final test evaluation on
# some new images we haven't used before.
- test_bottlenecks, test_ground_truth, test_filenames = (
- get_random_cached_bottlenecks(
- sess, image_lists, FLAGS.test_batch_size, 'testing',
- FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
- decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
- FLAGS.architecture))
- test_accuracy, predictions = sess.run(
- [evaluation_step, prediction],
- feed_dict={bottleneck_input: test_bottlenecks,
- ground_truth_input: test_ground_truth})
- tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
- (test_accuracy * 100, len(test_bottlenecks)))
-
- if FLAGS.print_misclassified_test_images:
- tf.logging.info('=== MISCLASSIFIED TEST IMAGES ===')
- for i, test_filename in enumerate(test_filenames):
- if predictions[i] != test_ground_truth[i]:
- tf.logging.info('%70s %s' %
- (test_filename,
- list(image_lists.keys())[predictions[i]]))
+ run_final_eval(sess, model_info, class_count, image_lists, jpeg_data_tensor,
+ decoded_image_tensor, resized_image_tensor,
+ bottleneck_tensor)
# Write out the trained graph and labels with the weights stored as
# constants.
- save_graph_to_file(sess, graph, FLAGS.output_graph)
+ save_graph_to_file(graph, FLAGS.output_graph, model_info, class_count)
with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
f.write('\n'.join(image_lists.keys()) + '\n')
- export_model(sess, FLAGS.architecture, FLAGS.saved_model_dir)
+ export_model(model_info, class_count, FLAGS.saved_model_dir)
if __name__ == '__main__':
@@ -1406,8 +1472,9 @@ if __name__ == '__main__':
form 'mobilenet_<parameter size>_<input_size>[_quantized]'. For example,
'mobilenet_1.0_224' will pick a model that is 17 MB in size and takes 224
pixel input images, while 'mobilenet_0.25_128_quantized' will choose a much
- less accurate, but smaller and faster network that's 920 KB on disk and
- takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
+ smaller and less accurate model, taking 128x128 images, and instrumented
+ for eventual quantization via TensorFlow Lite.
+ See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
for more information on Mobilenet.\
""")
parser.add_argument(
diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py
index 8b8dd45fd7..fb7324c58a 100644
--- a/tensorflow/examples/image_retraining/retrain_test.py
+++ b/tensorflow/examples/image_retraining/retrain_test.py
@@ -67,22 +67,52 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name('DistortResult:0'))
@tf.test.mock.patch.object(retrain, 'FLAGS', learning_rate=0.01)
- def testAddFinalTrainingOps(self, flags_mock):
+ def testAddFinalRetrainOps(self, flags_mock):
with tf.Graph().as_default():
with tf.Session() as sess:
bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck')
- # Test creating final training op with quantization
- retrain.add_final_training_ops(5, 'final', bottleneck, 1024, False)
+ # Test creating final training op with quantization.
+ retrain.add_final_retrain_ops(5, 'final', bottleneck, 1024, False,
+ False)
self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0'))
@tf.test.mock.patch.object(retrain, 'FLAGS', learning_rate=0.01)
- def testAddFinalTrainingOpsQuantized(self, flags_mock):
- with tf.Graph().as_default():
+ def testAddFinalRetrainOpsQuantized(self, flags_mock):
+ # Ensure that the training and eval graph for quantized models are correctly
+ # created.
+ with tf.Graph().as_default() as g:
+ with tf.Session() as sess:
+ bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck')
+ # Test creating final training op with quantization, set is_training to
+ # true.
+ retrain.add_final_retrain_ops(5, 'final', bottleneck, 1024, True, True)
+ self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0'))
+ found_fake_quant = 0
+ for op in g.get_operations():
+ if op.type == 'FakeQuantWithMinMaxVars':
+ found_fake_quant += 1
+ # Ensure that the inputs of each FakeQuant operations has 2 Assign
+ # operations in the training graph (Assign[Min,Max]Last,
+ # Assign[Min,Max]Ema)
+ self.assertEqual(2,
+ len([i for i in op.inputs if 'Assign' in i.name]))
+ self.assertEqual(found_fake_quant, 2)
+ with tf.Graph().as_default() as g:
with tf.Session() as sess:
bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck')
- # Test creating final training op with quantization
- retrain.add_final_training_ops(5, 'final', bottleneck, 1024, True)
+ # Test creating final training op with quantization, set is_training to
+ # false.
+ retrain.add_final_retrain_ops(5, 'final', bottleneck, 1024, True, False)
self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0'))
+ found_fake_quant = 0
+ for op in g.get_operations():
+ if op.type == 'FakeQuantWithMinMaxVars':
+ found_fake_quant += 1
+ for i in op.inputs:
+ # Ensure that no operations are Assign operation since this is the
+ # evaluation graph.
+ self.assertTrue('Assign' not in i.name)
+ self.assertEqual(found_fake_quant, 2)
def testAddEvaluationStep(self):
with tf.Graph().as_default():