aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/eager/python/examples/revnet/cifar_input.py')
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/cifar_input.py35
1 files changed, 23 insertions, 12 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py
index 3bc69da5ad..e1d8b3a055 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py
@@ -26,8 +26,6 @@ import tensorflow as tf
IMAGE_HEIGHT = 32
IMAGE_WIDTH = 32
NUM_CHANNEL = 3
-NUM_TRAIN_IMG = 50000
-NUM_TEST_IMG = 10000
def get_ds_from_tfrecords(data_dir,
@@ -37,8 +35,8 @@ def get_ds_from_tfrecords(data_dir,
epochs=None,
shuffle=True,
data_format="channels_first",
- num_parallel_calls=4,
- prefetch=True,
+ num_parallel_calls=8,
+ prefetch=0,
div255=True,
dtype=tf.float32):
"""Returns a tf.train.Dataset object from reading tfrecords.
@@ -48,11 +46,12 @@ def get_ds_from_tfrecords(data_dir,
split: "train", "validation", or "test"
data_aug: Apply data augmentation if True
batch_size: Batch size of dataset object
- epochs: Number of epochs to repeat the dataset
+ epochs: Number of epochs to repeat the dataset; default `None` means
+ repeating indefinitely
shuffle: Shuffle the dataset if True
data_format: `channels_first` or `channels_last`
num_parallel_calls: Number of threads for dataset preprocess
- prefetch: Apply prefetch for the dataset if True
+ prefetch: Buffer size for prefetch
div255: Divide the images by 255 if True
dtype: Data type of images
Returns:
@@ -62,7 +61,7 @@ def get_ds_from_tfrecords(data_dir,
ValueError: Unknown split
"""
- if split not in ["train", "validation", "test"]:
+ if split not in ["train", "validation", "test", "train_all"]:
raise ValueError("Unknown split {}".format(split))
def _parser(serialized_example):
@@ -74,7 +73,11 @@ def get_ds_from_tfrecords(data_dir,
"label": tf.FixedLenFeature([], tf.int64),
})
image = tf.decode_raw(features["image"], tf.uint8)
- image = tf.reshape(image, [IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNEL])
+ # Initially reshaping to [H, W, C] does not work
+ image = tf.reshape(image, [NUM_CHANNEL, IMAGE_HEIGHT, IMAGE_WIDTH])
+ # This is needed for `tf.image.resize_image_with_crop_or_pad`
+ image = tf.transpose(image, [1, 2, 0])
+
image = tf.cast(image, dtype)
label = tf.cast(features["label"], tf.int32)
@@ -93,13 +96,21 @@ def get_ds_from_tfrecords(data_dir,
return image, label
filename = os.path.join(data_dir, split + ".tfrecords")
- dataset = tf.data.TFRecordDataset(filename).repeat(epochs)
+ dataset = tf.data.TFRecordDataset(filename)
+ dataset = dataset.repeat(epochs)
dataset = dataset.map(_parser, num_parallel_calls=num_parallel_calls)
+ dataset = dataset.prefetch(prefetch)
- if prefetch:
- dataset = dataset.prefetch(batch_size)
if shuffle:
- dataset = dataset.shuffle(NUM_TRAIN_IMG)
+ # Find the right size according to the split
+ size = {
+ "train": 40000,
+ "validation": 10000,
+ "test": 10000,
+ "train_all": 50000
+ }[split]
+ dataset = dataset.shuffle(size)
+
dataset = dataset.batch(batch_size)
return dataset