diff options
Diffstat (limited to 'tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py')
-rw-r--r-- | tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py | 89 |
1 files changed, 60 insertions, 29 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py index f79428b2a9..377844ad8f 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py +++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Read CIFAR-10 data from pickled numpy arrays and writes TFRecords. +"""Read CIFAR data from pickled numpy arrays and writes TFRecords. Generates tf.train.Example protos and writes them to TFRecord files from the -python version of the CIFAR-10 dataset downloaded from +python version of the CIFAR dataset downloaded from https://www.cs.toronto.edu/~kriz/cifar.html. """ @@ -32,20 +32,22 @@ from six.moves import cPickle as pickle from six.moves import urllib import tensorflow as tf -CIFAR_FILENAME = 'cifar-10-python.tar.gz' -CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME -CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py' +BASE_URL = 'https://www.cs.toronto.edu/~kriz/' +CIFAR_FILE_NAMES = ['cifar-10-python.tar.gz', 'cifar-100-python.tar.gz'] +CIFAR_DOWNLOAD_URLS = [BASE_URL + name for name in CIFAR_FILE_NAMES] +CIFAR_LOCAL_FOLDERS = ['cifar-10', 'cifar-100'] +EXTRACT_FOLDERS = ['cifar-10-batches-py', 'cifar-100-python'] -def download_and_extract(data_dir): - """Download CIFAR-10 if not already downloaded.""" - filepath = os.path.join(data_dir, CIFAR_FILENAME) +def download_and_extract(data_dir, file_name, url): + """Download CIFAR if not already downloaded.""" + filepath = os.path.join(data_dir, file_name) if tf.gfile.Exists(filepath): return filepath if not tf.gfile.Exists(data_dir): tf.gfile.MakeDirs(data_dir) - urllib.request.urlretrieve(CIFAR_DOWNLOAD_URL, filepath) + urllib.request.urlretrieve(url, filepath) tarfile.open(os.path.join(filepath), 'r:gz').extractall(data_dir) return filepath @@ -58,12 +60,22 @@ def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) -def _get_file_names(): +def _get_file_names(folder): """Returns the file names expected to exist in the input_dir.""" + assert folder in ['cifar-10', 'cifar-100'] + file_names = {} - file_names['train'] = ['data_batch_%d' % i for i in range(1, 5)] - file_names['validation'] = ['data_batch_5'] - file_names['test'] = ['test_batch'] + if folder == 'cifar-10': + file_names['train'] = ['data_batch_%d' % i for i in range(1, 5)] + file_names['validation'] = ['data_batch_5'] + file_names['train_all'] = ['data_batch_%d' % i for i in range(1, 6)] + file_names['test'] = ['test_batch'] + else: + file_names['train_all'] = ['train'] + file_names['test'] = ['test'] + # Split in `convert_to_tfrecord` function + file_names['train'] = ['train'] + file_names['validation'] = ['train'] return file_names @@ -76,14 +88,28 @@ def read_pickle_from_file(filename): return data_dict -def convert_to_tfrecord(input_files, output_file): +def convert_to_tfrecord(input_files, output_file, folder): """Converts files with pickled data to TFRecords.""" + assert folder in ['cifar-10', 'cifar-100'] + print('Generating %s' % output_file) with tf.python_io.TFRecordWriter(output_file) as record_writer: for input_file in input_files: data_dict = read_pickle_from_file(input_file) data = data_dict[b'data'] - labels = data_dict[b'labels'] + try: + labels = data_dict[b'labels'] + except KeyError: + labels = data_dict[b'fine_labels'] + + if folder == 'cifar-100' and input_file.endswith('train.tfrecords'): + data = data[:40000] + labels = labels[:40000] + elif folder == 'cifar-100' and input_file.endswith( + 'validation.tfrecords'): + data = data[40000:] + labels = labels[40000:] + num_entries_in_batch = len(labels) for i in range(num_entries_in_batch): @@ -97,19 +123,24 @@ def convert_to_tfrecord(input_files, output_file): def main(_): - print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL)) - download_and_extract(FLAGS.data_dir) - file_names = _get_file_names() - input_dir = os.path.join(FLAGS.data_dir, CIFAR_LOCAL_FOLDER) - - for mode, files in file_names.items(): - input_files = [os.path.join(input_dir, f) for f in files] - output_file = os.path.join(FLAGS.data_dir, mode + '.tfrecords') - try: - os.remove(output_file) - except OSError: - pass - convert_to_tfrecord(input_files, output_file) + for file_name, url, folder, extract_folder in zip( + CIFAR_FILE_NAMES, CIFAR_DOWNLOAD_URLS, CIFAR_LOCAL_FOLDERS, + EXTRACT_FOLDERS): + print('Download from {} and extract.'.format(url)) + data_dir = os.path.join(FLAGS.data_dir, folder) + download_and_extract(data_dir, file_name, url) + file_names = _get_file_names(folder) + input_dir = os.path.join(data_dir, extract_folder) + + for mode, files in file_names.items(): + input_files = [os.path.join(input_dir, f) for f in files] + output_file = os.path.join(data_dir, mode + '.tfrecords') + try: + os.remove(output_file) + except OSError: + pass + convert_to_tfrecord(input_files, output_file, folder) + print('Done!') @@ -118,6 +149,6 @@ if __name__ == '__main__': flags.DEFINE_string( 'data_dir', default=None, - help='Directory to download and extract CIFAR-10 to.') + help='Directory to download, extract and store TFRecords.') tf.app.run(main) |