diff options
author | Dan Mané <danmane@google.com> | 2016-08-11 11:20:02 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-11 12:31:38 -0700 |
commit | d9527a0a622cb850f6ef2259f40bbc3c84a8a475 (patch) | |
tree | eec27205fb0503acfe35440502b74c559d1d818e /tensorflow/examples/image_retraining | |
parent | c5dfa2e4231dd5acd2ffb5f54d22e907edadf61e (diff) |
Automated rollback of change 129807750
Change: 130017620
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r-- | tensorflow/examples/image_retraining/BUILD | 9 | ||||
-rw-r--r-- | tensorflow/examples/image_retraining/retrain.py | 18 | ||||
-rw-r--r-- | tensorflow/examples/image_retraining/retrain_test.py | 32 |
3 files changed, 11 insertions, 48 deletions
diff --git a/tensorflow/examples/image_retraining/BUILD b/tensorflow/examples/image_retraining/BUILD index b1f83f76d5..4cf6adecb9 100644 --- a/tensorflow/examples/image_retraining/BUILD +++ b/tensorflow/examples/image_retraining/BUILD @@ -21,14 +21,9 @@ py_test( name = "retrain_test", size = "small", srcs = [ - "label_image.py", "retrain.py", "retrain_test.py", ], - data = [ - ":data/labels.txt", - "//tensorflow/examples/label_image:data/grace_hopper.jpg", - ], srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", @@ -40,9 +35,7 @@ py_test( filegroup( name = "all_files", srcs = glob( - [ - "**/*", - ], + ["**/*"], exclude = [ "**/METADATA", "**/OWNERS", diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py index 8ca3c570d7..6a3024d5bc 100644 --- a/tensorflow/examples/image_retraining/retrain.py +++ b/tensorflow/examples/image_retraining/retrain.py @@ -79,6 +79,8 @@ import tensorflow as tf from tensorflow.python.framework import graph_util from tensorflow.python.framework import tensor_shape +from tensorflow.python.platform import gfile + FLAGS = tf.app.flags.FLAGS @@ -178,7 +180,7 @@ def create_image_lists(image_dir, testing_percentage, validation_percentage): A dictionary containing an entry for each label subfolder, with images split into training, testing, and validation sets within each label. """ - if not tf.gfile.Exists(image_dir): + if not gfile.Exists(image_dir): print("Image directory '" + image_dir + "' not found.") return None result = {} @@ -301,7 +303,7 @@ def create_inception_graph(): with tf.Session() as sess: model_filename = os.path.join( FLAGS.model_dir, 'classify_image_graph_def.pb') - with tf.gfile.FastGFile(model_filename, 'rb') as f: + with gfile.FastGFile(model_filename, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = ( @@ -404,9 +406,9 @@ def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir, print('Creating bottleneck at ' + bottleneck_path) image_path = get_image_path(image_lists, label_name, index, image_dir, category) - if not tf.gfile.Exists(image_path): + if not gfile.Exists(image_path): tf.logging.fatal('File does not exist %s', image_path) - image_data = tf.gfile.FastGFile(image_path, 'rb').read() + image_data = gfile.FastGFile(image_path, 'rb').read() bottleneck_values = run_bottleneck_on_image(sess, image_data, jpeg_data_tensor, bottleneck_tensor) @@ -535,9 +537,9 @@ def get_random_distorted_bottlenecks( image_index = random.randrange(65536) image_path = get_image_path(image_lists, label_name, image_index, image_dir, category) - if not tf.gfile.Exists(image_path): + if not gfile.Exists(image_path): tf.logging.fatal('File does not exist %s', image_path) - jpeg_data = tf.gfile.FastGFile(image_path, 'rb').read() + jpeg_data = gfile.FastGFile(image_path, 'rb').read() # Note that we materialize the distorted_image_data as a numpy array before # sending running inference on the image. This involves 2 memory copies and # might be optimized in other implementations. @@ -876,9 +878,9 @@ def main(_): # Write out the trained graph and labels with the weights stored as constants. output_graph_def = graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), [FLAGS.final_tensor_name]) - with tf.gfile.FastGFile(FLAGS.output_graph, 'wb') as f: + with gfile.FastGFile(FLAGS.output_graph, 'wb') as f: f.write(output_graph_def.SerializeToString()) - with tf.gfile.FastGFile(FLAGS.output_labels, 'w') as f: + with gfile.FastGFile(FLAGS.output_labels, 'w') as f: f.write('\n'.join(image_lists.keys()) + '\n') diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py index fb9acea3eb..072998ae60 100644 --- a/tensorflow/examples/image_retraining/retrain_test.py +++ b/tensorflow/examples/image_retraining/retrain_test.py @@ -19,9 +19,7 @@ from __future__ import division from __future__ import print_function import tensorflow as tf -import os -from tensorflow.examples.image_retraining import label_image from tensorflow.examples.image_retraining import retrain from tensorflow.python.framework import test_util @@ -82,35 +80,5 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase): gt = tf.placeholder(tf.float32, [1], name='gt') self.assertIsNotNone(retrain.add_evaluation_step(final, gt)) - def testLabelImage(self): - - image_filename = ('../label_image/data/grace_hopper.jpg') - - # Load some default data - label_path = os.path.join(tf.resource_loader.get_data_files_path(), - 'data/labels.txt') - labels = label_image.load_labels(label_path) - self.assertEqual(len(labels), 3) - - image_path = os.path.join(tf.resource_loader.get_data_files_path(), - image_filename) - - image = label_image.load_image(image_path) - self.assertEqual(len(image), 61306) - - # Create trivial graph; note that the two nodes don't meet - with tf.Graph().as_default(): - jpeg = tf.constant(image) - # Input node that doesn't lead anywhere. - tf.image.decode_jpeg(jpeg, name='DecodeJpeg') - - # Output node, that always outputs a constant. - final = tf.constant([[10, 30, 5]], name='final') - - # As label_image outputs via print, we assume that - # if it returns, everything is OK. - result = label_image.run_graph(image, labels, jpeg, final) - self.assertEqual(result, 0) - if __name__ == '__main__': tf.test.main() |