diff options
Diffstat (limited to 'tensorflow/examples/image_retraining/retrain_test.py')
-rw-r--r-- | tensorflow/examples/image_retraining/retrain_test.py | 23 |
1 files changed, 19 insertions, 4 deletions
diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py index c342a17dd8..2de4c4ec99 100644 --- a/tensorflow/examples/image_retraining/retrain_test.py +++ b/tensorflow/examples/image_retraining/retrain_test.py @@ -70,10 +70,18 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase): def testAddFinalTrainingOps(self, flags_mock): with tf.Graph().as_default(): with tf.Session() as sess: - bottleneck = tf.placeholder( - tf.float32, [1, 1024], - name='bottleneck') - retrain.add_final_training_ops(5, 'final', bottleneck, 1024) + bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck') + # Test creating final training op with quantization + retrain.add_final_training_ops(5, 'final', bottleneck, 1024, False) + self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0')) + + @tf.test.mock.patch.object(retrain, 'FLAGS', learning_rate=0.01) + def testAddFinalTrainingOpsQuantized(self, flags_mock): + with tf.Graph().as_default(): + with tf.Session() as sess: + bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck') + # Test creating final training op with quantization + retrain.add_final_training_ops(5, 'final', bottleneck, 1024, True) self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0')) def testAddEvaluationStep(self): @@ -99,5 +107,12 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase): self.assertIsNotNone(model_info) self.assertEqual(299, model_info['input_width']) + def testCreateModelInfoQuantized(self): + # Test for mobilenet_quantized + model_info = retrain.create_model_info('mobilenet_1.0_224') + self.assertIsNotNone(model_info) + self.assertEqual(224, model_info['input_width']) + + if __name__ == '__main__': tf.test.main() |