author    A. Unique TensorFlower <nobody@tensorflow.org>  2016-02-09 12:56:46 -0800
committer Vijay Vasudevan <vrv@google.com>                2016-02-09 13:06:51 -0800
commit    27bbe92711a93613eca843772b6e7eb32ff96c35 (patch)
tree      04819ee6b65774ee1496bc62ea50d2cc057608df /tensorflow/models/image
parent    3c13ae058ea45d855d8029b3d19f6567b86430b5 (diff)
Make the gfile package available when importing tensorflow.
Update programs that were importing both 'tensorflow' and 'gfile' to use 'gfile' from the tensorflow import.

Change: 114249943
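For reference, a minimal before/after sketch of the migration applied to the models below; the '/tmp/cifar10_train' path is illustrative, not taken from this diff:

    # Before: gfile had to be imported from the internal platform package.
    from tensorflow.python.platform import gfile

    if gfile.Exists('/tmp/cifar10_train'):
      gfile.DeleteRecursively('/tmp/cifar10_train')
    gfile.MakeDirs('/tmp/cifar10_train')

    # After: the same helpers are reachable from the top-level import.
    import tensorflow as tf

    if tf.gfile.Exists('/tmp/cifar10_train'):
      tf.gfile.DeleteRecursively('/tmp/cifar10_train')
    tf.gfile.MakeDirs('/tmp/cifar10_train')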
Diffstat (limited to 'tensorflow/models/image')
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_eval.py            |  8
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_input.py           |  6
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py |  8
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_train.py           |  7
-rw-r--r--  tensorflow/models/image/imagenet/classify_image.py         | 18
-rw-r--r--  tensorflow/models/image/mnist/convolutional.py             | 11
6 files changed, 27 insertions(+), 31 deletions(-)
diff --git a/tensorflow/models/image/cifar10/cifar10_eval.py b/tensorflow/models/image/cifar10/cifar10_eval.py
index 6dc1db7248..9ba89e4e1e 100644
--- a/tensorflow/models/image/cifar10/cifar10_eval.py
+++ b/tensorflow/models/image/cifar10/cifar10_eval.py
@@ -39,7 +39,7 @@ import math
import time
import tensorflow.python.platform
-from tensorflow.python.platform import gfile
+
import numpy as np
import tensorflow as tf
@@ -151,9 +151,9 @@ def evaluate():
def main(argv=None): # pylint: disable=unused-argument
cifar10.maybe_download_and_extract()
- if gfile.Exists(FLAGS.eval_dir):
- gfile.DeleteRecursively(FLAGS.eval_dir)
- gfile.MakeDirs(FLAGS.eval_dir)
+ if tf.gfile.Exists(FLAGS.eval_dir):
+ tf.gfile.DeleteRecursively(FLAGS.eval_dir)
+ tf.gfile.MakeDirs(FLAGS.eval_dir)
evaluate()
diff --git a/tensorflow/models/image/cifar10/cifar10_input.py b/tensorflow/models/image/cifar10/cifar10_input.py
index f7d7083d73..d5e12c08b9 100644
--- a/tensorflow/models/image/cifar10/cifar10_input.py
+++ b/tensorflow/models/image/cifar10/cifar10_input.py
@@ -25,8 +25,6 @@ import tensorflow.python.platform
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
-from tensorflow.python.platform import gfile
-
# Process images of this size. Note that this differs from the original CIFAR
# image size of 32 x 32. If one alters this number, then the entire model
# architecture will change and any model would need to be retrained.
@@ -144,7 +142,7 @@ def distorted_inputs(data_dir, batch_size):
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
for i in xrange(1, 6)]
for f in filenames:
- if not gfile.Exists(f):
+ if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: ' + f)
# Create a queue that produces the filenames to read.
@@ -209,7 +207,7 @@ def inputs(eval_data, data_dir, batch_size):
num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
for f in filenames:
- if not gfile.Exists(f):
+ if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: ' + f)
# Create a queue that produces the filenames to read.
diff --git a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
index f594b86627..b7c07435af 100644
--- a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
+++ b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
@@ -46,7 +46,7 @@ import time
# pylint: disable=unused-import,g-bad-import-order
import tensorflow.python.platform
-from tensorflow.python.platform import gfile
+
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
@@ -275,9 +275,9 @@ def train():
def main(argv=None): # pylint: disable=unused-argument
cifar10.maybe_download_and_extract()
- if gfile.Exists(FLAGS.train_dir):
- gfile.DeleteRecursively(FLAGS.train_dir)
- gfile.MakeDirs(FLAGS.train_dir)
+ if tf.gfile.Exists(FLAGS.train_dir):
+ tf.gfile.DeleteRecursively(FLAGS.train_dir)
+ tf.gfile.MakeDirs(FLAGS.train_dir)
train()
diff --git a/tensorflow/models/image/cifar10/cifar10_train.py b/tensorflow/models/image/cifar10/cifar10_train.py
index fb2ef56e1e..1882f256bd 100644
--- a/tensorflow/models/image/cifar10/cifar10_train.py
+++ b/tensorflow/models/image/cifar10/cifar10_train.py
@@ -41,7 +41,6 @@ import os.path
import time
import tensorflow.python.platform
-from tensorflow.python.platform import gfile
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -128,9 +127,9 @@ def train():
def main(argv=None): # pylint: disable=unused-argument
cifar10.maybe_download_and_extract()
- if gfile.Exists(FLAGS.train_dir):
- gfile.DeleteRecursively(FLAGS.train_dir)
- gfile.MakeDirs(FLAGS.train_dir)
+ if tf.gfile.Exists(FLAGS.train_dir):
+ tf.gfile.DeleteRecursively(FLAGS.train_dir)
+ tf.gfile.MakeDirs(FLAGS.train_dir)
train()
diff --git a/tensorflow/models/image/imagenet/classify_image.py b/tensorflow/models/image/imagenet/classify_image.py
index 2459f8a633..838cc568ba 100644
--- a/tensorflow/models/image/imagenet/classify_image.py
+++ b/tensorflow/models/image/imagenet/classify_image.py
@@ -47,8 +47,6 @@ import numpy as np
import tensorflow as tf
# pylint: enable=unused-import,g-bad-import-order
-from tensorflow.python.platform import gfile
-
FLAGS = tf.app.flags.FLAGS
# classify_image_graph_def.pb:
@@ -96,13 +94,13 @@ class NodeLookup(object):
Returns:
dict from integer node ID to human-readable string.
"""
- if not gfile.Exists(uid_lookup_path):
+ if not tf.gfile.Exists(uid_lookup_path):
tf.logging.fatal('File does not exist %s', uid_lookup_path)
- if not gfile.Exists(label_lookup_path):
+ if not tf.gfile.Exists(label_lookup_path):
tf.logging.fatal('File does not exist %s', label_lookup_path)
# Loads mapping from string UID to human-readable string
- proto_as_ascii_lines = gfile.GFile(uid_lookup_path).readlines()
+ proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
uid_to_human = {}
p = re.compile(r'[n\d]*[ \S,]*')
for line in proto_as_ascii_lines:
@@ -113,7 +111,7 @@ class NodeLookup(object):
# Loads mapping from string UID to integer node ID.
node_id_to_uid = {}
- proto_as_ascii = gfile.GFile(label_lookup_path).readlines()
+ proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
for line in proto_as_ascii:
if line.startswith(' target_class:'):
target_class = int(line.split(': ')[1])
@@ -138,9 +136,9 @@ class NodeLookup(object):
def create_graph():
- """"Creates a graph from saved GraphDef file and returns a saver."""
+ """Creates a graph from saved GraphDef file and returns a saver."""
# Creates graph from saved graph_def.pb.
- with gfile.FastGFile(os.path.join(
+ with tf.gfile.FastGFile(os.path.join(
FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
@@ -156,9 +154,9 @@ def run_inference_on_image(image):
Returns:
Nothing
"""
- if not gfile.Exists(image):
+ if not tf.gfile.Exists(image):
tf.logging.fatal('File does not exist %s', image)
- image_data = gfile.FastGFile(image, 'rb').read()
+ image_data = tf.gfile.FastGFile(image, 'rb').read()
# Creates graph from saved GraphDef.
create_graph()
diff --git a/tensorflow/models/image/mnist/convolutional.py b/tensorflow/models/image/mnist/convolutional.py
index edceb2a1ec..c0dfcc7979 100644
--- a/tensorflow/models/image/mnist/convolutional.py
+++ b/tensorflow/models/image/mnist/convolutional.py
@@ -55,13 +55,14 @@ FLAGS = tf.app.flags.FLAGS
def maybe_download(filename):
"""Download the data from Yann's website, unless it's already here."""
- if not os.path.exists(WORK_DIRECTORY):
- os.mkdir(WORK_DIRECTORY)
+ if not tf.gfile.Exists(WORK_DIRECTORY):
+ tf.gfile.MakeDirs(WORK_DIRECTORY)
filepath = os.path.join(WORK_DIRECTORY, filename)
- if not os.path.exists(filepath):
+ if not tf.gfile.Exists(filepath):
filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
- statinfo = os.stat(filepath)
- print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
+ with tf.gfile.GFile(filepath) as f:
+ size = f.Size()
+ print('Successfully downloaded', filename, size, 'bytes.')
return filepath