aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py')
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py89
1 files changed, 60 insertions, 29 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py
index f79428b2a9..377844ad8f 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Read CIFAR-10 data from pickled numpy arrays and writes TFRecords.
+"""Read CIFAR data from pickled numpy arrays and writes TFRecords.
Generates tf.train.Example protos and writes them to TFRecord files from the
-python version of the CIFAR-10 dataset downloaded from
+python version of the CIFAR dataset downloaded from
https://www.cs.toronto.edu/~kriz/cifar.html.
"""
@@ -32,20 +32,22 @@ from six.moves import cPickle as pickle
from six.moves import urllib
import tensorflow as tf
-CIFAR_FILENAME = 'cifar-10-python.tar.gz'
-CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME
-CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py'
+BASE_URL = 'https://www.cs.toronto.edu/~kriz/'
+CIFAR_FILE_NAMES = ['cifar-10-python.tar.gz', 'cifar-100-python.tar.gz']
+CIFAR_DOWNLOAD_URLS = [BASE_URL + name for name in CIFAR_FILE_NAMES]
+CIFAR_LOCAL_FOLDERS = ['cifar-10', 'cifar-100']
+EXTRACT_FOLDERS = ['cifar-10-batches-py', 'cifar-100-python']
-def download_and_extract(data_dir):
- """Download CIFAR-10 if not already downloaded."""
- filepath = os.path.join(data_dir, CIFAR_FILENAME)
+def download_and_extract(data_dir, file_name, url):
+ """Download CIFAR if not already downloaded."""
+ filepath = os.path.join(data_dir, file_name)
if tf.gfile.Exists(filepath):
return filepath
if not tf.gfile.Exists(data_dir):
tf.gfile.MakeDirs(data_dir)
- urllib.request.urlretrieve(CIFAR_DOWNLOAD_URL, filepath)
+ urllib.request.urlretrieve(url, filepath)
tarfile.open(os.path.join(filepath), 'r:gz').extractall(data_dir)
return filepath
@@ -58,12 +60,22 @@ def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
-def _get_file_names():
+def _get_file_names(folder):
"""Returns the file names expected to exist in the input_dir."""
+ assert folder in ['cifar-10', 'cifar-100']
+
file_names = {}
- file_names['train'] = ['data_batch_%d' % i for i in range(1, 5)]
- file_names['validation'] = ['data_batch_5']
- file_names['test'] = ['test_batch']
+ if folder == 'cifar-10':
+ file_names['train'] = ['data_batch_%d' % i for i in range(1, 5)]
+ file_names['validation'] = ['data_batch_5']
+ file_names['train_all'] = ['data_batch_%d' % i for i in range(1, 6)]
+ file_names['test'] = ['test_batch']
+ else:
+ file_names['train_all'] = ['train']
+ file_names['test'] = ['test']
+ # Split in `convert_to_tfrecord` function
+ file_names['train'] = ['train']
+ file_names['validation'] = ['train']
return file_names
@@ -76,14 +88,28 @@ def read_pickle_from_file(filename):
return data_dict
-def convert_to_tfrecord(input_files, output_file):
+def convert_to_tfrecord(input_files, output_file, folder):
"""Converts files with pickled data to TFRecords."""
+ assert folder in ['cifar-10', 'cifar-100']
+
print('Generating %s' % output_file)
with tf.python_io.TFRecordWriter(output_file) as record_writer:
for input_file in input_files:
data_dict = read_pickle_from_file(input_file)
data = data_dict[b'data']
- labels = data_dict[b'labels']
+ try:
+ labels = data_dict[b'labels']
+ except KeyError:
+ labels = data_dict[b'fine_labels']
+
+ if folder == 'cifar-100' and input_file.endswith('train.tfrecords'):
+ data = data[:40000]
+ labels = labels[:40000]
+ elif folder == 'cifar-100' and input_file.endswith(
+ 'validation.tfrecords'):
+ data = data[40000:]
+ labels = labels[40000:]
+
num_entries_in_batch = len(labels)
for i in range(num_entries_in_batch):
@@ -97,19 +123,24 @@ def convert_to_tfrecord(input_files, output_file):
def main(_):
- print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL))
- download_and_extract(FLAGS.data_dir)
- file_names = _get_file_names()
- input_dir = os.path.join(FLAGS.data_dir, CIFAR_LOCAL_FOLDER)
-
- for mode, files in file_names.items():
- input_files = [os.path.join(input_dir, f) for f in files]
- output_file = os.path.join(FLAGS.data_dir, mode + '.tfrecords')
- try:
- os.remove(output_file)
- except OSError:
- pass
- convert_to_tfrecord(input_files, output_file)
+ for file_name, url, folder, extract_folder in zip(
+ CIFAR_FILE_NAMES, CIFAR_DOWNLOAD_URLS, CIFAR_LOCAL_FOLDERS,
+ EXTRACT_FOLDERS):
+ print('Download from {} and extract.'.format(url))
+ data_dir = os.path.join(FLAGS.data_dir, folder)
+ download_and_extract(data_dir, file_name, url)
+ file_names = _get_file_names(folder)
+ input_dir = os.path.join(data_dir, extract_folder)
+
+ for mode, files in file_names.items():
+ input_files = [os.path.join(input_dir, f) for f in files]
+ output_file = os.path.join(data_dir, mode + '.tfrecords')
+ try:
+ os.remove(output_file)
+ except OSError:
+ pass
+ convert_to_tfrecord(input_files, output_file, folder)
+
print('Done!')
@@ -118,6 +149,6 @@ if __name__ == '__main__':
flags.DEFINE_string(
'data_dir',
default=None,
- help='Directory to download and extract CIFAR-10 to.')
+ help='Directory to download, extract and store TFRecords.')
tf.app.run(main)