aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/image_retraining/retrain_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/image_retraining/retrain_test.py')
-rw-r--r--tensorflow/examples/image_retraining/retrain_test.py23
1 files changed, 4 insertions, 19 deletions
diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py
index 2de4c4ec99..c342a17dd8 100644
--- a/tensorflow/examples/image_retraining/retrain_test.py
+++ b/tensorflow/examples/image_retraining/retrain_test.py
@@ -70,18 +70,10 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
def testAddFinalTrainingOps(self, flags_mock):
with tf.Graph().as_default():
with tf.Session() as sess:
- bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck')
- # Test creating final training op with quantization
- retrain.add_final_training_ops(5, 'final', bottleneck, 1024, False)
- self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0'))
-
- @tf.test.mock.patch.object(retrain, 'FLAGS', learning_rate=0.01)
- def testAddFinalTrainingOpsQuantized(self, flags_mock):
- with tf.Graph().as_default():
- with tf.Session() as sess:
- bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck')
- # Test creating final training op with quantization
- retrain.add_final_training_ops(5, 'final', bottleneck, 1024, True)
+ bottleneck = tf.placeholder(
+ tf.float32, [1, 1024],
+ name='bottleneck')
+ retrain.add_final_training_ops(5, 'final', bottleneck, 1024)
self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0'))
def testAddEvaluationStep(self):
@@ -107,12 +99,5 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
self.assertIsNotNone(model_info)
self.assertEqual(299, model_info['input_width'])
- def testCreateModelInfoQuantized(self):
- # Test for mobilenet_quantized
- model_info = retrain.create_model_info('mobilenet_1.0_224')
- self.assertIsNotNone(model_info)
- self.assertEqual(224, model_info['input_width'])
-
-
if __name__ == '__main__':
tf.test.main()