aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/image_retraining
diff options
context:
space:
mode:
authorGravatar Dan Mané <danmane@google.com>2016-08-11 11:20:02 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-11 12:31:38 -0700
commitd9527a0a622cb850f6ef2259f40bbc3c84a8a475 (patch)
treeeec27205fb0503acfe35440502b74c559d1d818e /tensorflow/examples/image_retraining
parentc5dfa2e4231dd5acd2ffb5f54d22e907edadf61e (diff)
Automated rollback of change 129807750
Change: 130017620
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r--tensorflow/examples/image_retraining/BUILD9
-rw-r--r--tensorflow/examples/image_retraining/retrain.py18
-rw-r--r--tensorflow/examples/image_retraining/retrain_test.py32
3 files changed, 11 insertions, 48 deletions
diff --git a/tensorflow/examples/image_retraining/BUILD b/tensorflow/examples/image_retraining/BUILD
index b1f83f76d5..4cf6adecb9 100644
--- a/tensorflow/examples/image_retraining/BUILD
+++ b/tensorflow/examples/image_retraining/BUILD
@@ -21,14 +21,9 @@ py_test(
name = "retrain_test",
size = "small",
srcs = [
- "label_image.py",
"retrain.py",
"retrain_test.py",
],
- data = [
- ":data/labels.txt",
- "//tensorflow/examples/label_image:data/grace_hopper.jpg",
- ],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
@@ -40,9 +35,7 @@ py_test(
filegroup(
name = "all_files",
srcs = glob(
- [
- "**/*",
- ],
+ ["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 8ca3c570d7..6a3024d5bc 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -79,6 +79,8 @@ import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.platform import gfile
+
FLAGS = tf.app.flags.FLAGS
@@ -178,7 +180,7 @@ def create_image_lists(image_dir, testing_percentage, validation_percentage):
A dictionary containing an entry for each label subfolder, with images split
into training, testing, and validation sets within each label.
"""
- if not tf.gfile.Exists(image_dir):
+ if not gfile.Exists(image_dir):
print("Image directory '" + image_dir + "' not found.")
return None
result = {}
@@ -301,7 +303,7 @@ def create_inception_graph():
with tf.Session() as sess:
model_filename = os.path.join(
FLAGS.model_dir, 'classify_image_graph_def.pb')
- with tf.gfile.FastGFile(model_filename, 'rb') as f:
+ with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
@@ -404,9 +406,9 @@ def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
print('Creating bottleneck at ' + bottleneck_path)
image_path = get_image_path(image_lists, label_name, index, image_dir,
category)
- if not tf.gfile.Exists(image_path):
+ if not gfile.Exists(image_path):
tf.logging.fatal('File does not exist %s', image_path)
- image_data = tf.gfile.FastGFile(image_path, 'rb').read()
+ image_data = gfile.FastGFile(image_path, 'rb').read()
bottleneck_values = run_bottleneck_on_image(sess, image_data,
jpeg_data_tensor,
bottleneck_tensor)
@@ -535,9 +537,9 @@ def get_random_distorted_bottlenecks(
image_index = random.randrange(65536)
image_path = get_image_path(image_lists, label_name, image_index, image_dir,
category)
- if not tf.gfile.Exists(image_path):
+ if not gfile.Exists(image_path):
tf.logging.fatal('File does not exist %s', image_path)
- jpeg_data = tf.gfile.FastGFile(image_path, 'rb').read()
+ jpeg_data = gfile.FastGFile(image_path, 'rb').read()
# Note that we materialize the distorted_image_data as a numpy array before
# sending running inference on the image. This involves 2 memory copies and
# might be optimized in other implementations.
@@ -876,9 +878,9 @@ def main(_):
# Write out the trained graph and labels with the weights stored as constants.
output_graph_def = graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
- with tf.gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
+ with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
f.write(output_graph_def.SerializeToString())
- with tf.gfile.FastGFile(FLAGS.output_labels, 'w') as f:
+ with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
f.write('\n'.join(image_lists.keys()) + '\n')
diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py
index fb9acea3eb..072998ae60 100644
--- a/tensorflow/examples/image_retraining/retrain_test.py
+++ b/tensorflow/examples/image_retraining/retrain_test.py
@@ -19,9 +19,7 @@ from __future__ import division
from __future__ import print_function
import tensorflow as tf
-import os
-from tensorflow.examples.image_retraining import label_image
from tensorflow.examples.image_retraining import retrain
from tensorflow.python.framework import test_util
@@ -82,35 +80,5 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
gt = tf.placeholder(tf.float32, [1], name='gt')
self.assertIsNotNone(retrain.add_evaluation_step(final, gt))
- def testLabelImage(self):
-
- image_filename = ('../label_image/data/grace_hopper.jpg')
-
- # Load some default data
- label_path = os.path.join(tf.resource_loader.get_data_files_path(),
- 'data/labels.txt')
- labels = label_image.load_labels(label_path)
- self.assertEqual(len(labels), 3)
-
- image_path = os.path.join(tf.resource_loader.get_data_files_path(),
- image_filename)
-
- image = label_image.load_image(image_path)
- self.assertEqual(len(image), 61306)
-
- # Create trivial graph; note that the two nodes don't meet
- with tf.Graph().as_default():
- jpeg = tf.constant(image)
- # Input node that doesn't lead anywhere.
- tf.image.decode_jpeg(jpeg, name='DecodeJpeg')
-
- # Output node, that always outputs a constant.
- final = tf.constant([[10, 30, 5]], name='final')
-
- # As label_image outputs via print, we assume that
- # if it returns, everything is OK.
- result = label_image.run_graph(image, labels, jpeg, final)
- self.assertEqual(result, 0)
-
if __name__ == '__main__':
tf.test.main()