diff options
Diffstat (limited to 'tensorflow/contrib/eager/python/examples/revnet/cifar_input.py')
-rw-r--r-- | tensorflow/contrib/eager/python/examples/revnet/cifar_input.py | 35 |
1 files changed, 23 insertions, 12 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py index 3bc69da5ad..e1d8b3a055 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py +++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py @@ -26,8 +26,6 @@ import tensorflow as tf IMAGE_HEIGHT = 32 IMAGE_WIDTH = 32 NUM_CHANNEL = 3 -NUM_TRAIN_IMG = 50000 -NUM_TEST_IMG = 10000 def get_ds_from_tfrecords(data_dir, @@ -37,8 +35,8 @@ def get_ds_from_tfrecords(data_dir, epochs=None, shuffle=True, data_format="channels_first", - num_parallel_calls=4, - prefetch=True, + num_parallel_calls=8, + prefetch=0, div255=True, dtype=tf.float32): """Returns a tf.train.Dataset object from reading tfrecords. @@ -48,11 +46,12 @@ def get_ds_from_tfrecords(data_dir, split: "train", "validation", or "test" data_aug: Apply data augmentation if True batch_size: Batch size of dataset object - epochs: Number of epochs to repeat the dataset + epochs: Number of epochs to repeat the dataset; default `None` means + repeating indefinitely shuffle: Shuffle the dataset if True data_format: `channels_first` or `channels_last` num_parallel_calls: Number of threads for dataset preprocess - prefetch: Apply prefetch for the dataset if True + prefetch: Buffer size for prefetch div255: Divide the images by 255 if True dtype: Data type of images Returns: @@ -62,7 +61,7 @@ def get_ds_from_tfrecords(data_dir, ValueError: Unknown split """ - if split not in ["train", "validation", "test"]: + if split not in ["train", "validation", "test", "train_all"]: raise ValueError("Unknown split {}".format(split)) def _parser(serialized_example): @@ -74,7 +73,11 @@ def get_ds_from_tfrecords(data_dir, "label": tf.FixedLenFeature([], tf.int64), }) image = tf.decode_raw(features["image"], tf.uint8) - image = tf.reshape(image, [IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNEL]) + # Initially reshaping to [H, W, C] does not work + image = tf.reshape(image, [NUM_CHANNEL, IMAGE_HEIGHT, IMAGE_WIDTH]) + # This is needed for `tf.image.resize_image_with_crop_or_pad` + image = tf.transpose(image, [1, 2, 0]) + image = tf.cast(image, dtype) label = tf.cast(features["label"], tf.int32) @@ -93,13 +96,21 @@ def get_ds_from_tfrecords(data_dir, return image, label filename = os.path.join(data_dir, split + ".tfrecords") - dataset = tf.data.TFRecordDataset(filename).repeat(epochs) + dataset = tf.data.TFRecordDataset(filename) + dataset = dataset.repeat(epochs) dataset = dataset.map(_parser, num_parallel_calls=num_parallel_calls) + dataset = dataset.prefetch(prefetch) - if prefetch: - dataset = dataset.prefetch(batch_size) if shuffle: - dataset = dataset.shuffle(NUM_TRAIN_IMG) + # Find the right size according to the split + size = { + "train": 40000, + "validation": 10000, + "test": 10000, + "train_all": 50000 + }[split] + dataset = dataset.shuffle(size) + dataset = dataset.batch(batch_size) return dataset |