aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-17 08:51:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-17 10:08:42 -0700
commit83f2fe11b30c4111084f1d99f23396e8d01b45b7 (patch)
tree162a1511229b78f7f93dd874c3cf018ac133bfad /tensorflow
parent090e0743904a45bc3169b006ad90c7a4720f0998 (diff)
Remove `stratified_sample_unknown_dist`, since `stratified_sample` supports an unknown data distribution.
Change: 136361350
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/training/__init__.py13
-rw-r--r--tensorflow/contrib/training/python/training/sampling_ops.py166
-rw-r--r--tensorflow/contrib/training/python/training/sampling_ops_test.py196
3 files changed, 58 insertions, 317 deletions
diff --git a/tensorflow/contrib/training/__init__.py b/tensorflow/contrib/training/__init__.py
index 1290854260..c9564fc316 100644
--- a/tensorflow/contrib/training/__init__.py
+++ b/tensorflow/contrib/training/__init__.py
@@ -37,18 +37,15 @@ providing the desired rate for each example. If you wish to specify relative
rates, rather than absolute ones, use ['weighted_resample'](#weighted_resample)
(which also returns the actual resampling rate used for each output example).
-Use ['stratified_sample'](#stratified_sample) or
-['stratified_sample_unknown_dist'](#stratified_sample_unknown_dist) to
-resample without replacement from the data to achieve a desired mix of
-class proportions that the Tensorflow graph sees. For instance, if you
-have a binary classification dataset that is 99.9% class 1, a common
-approach is to resample from the data so that the data is more
-balanced.
+Use ['stratified_sample'](#stratified_sample) to resample without replacement
+from the data to achieve a desired mix of class proportions that the Tensorflow
+graph sees. For instance, if you have a binary classification dataset that is
+99.9% class 1, a common approach is to resample from the data so that the data
+is more balanced.
@@rejection_sample
@@resample_at_rate
@@stratified_sample
-@@stratified_sample_unknown_dist
@@weighted_resample
## Bucketing
diff --git a/tensorflow/contrib/training/python/training/sampling_ops.py b/tensorflow/contrib/training/python/training/sampling_ops.py
index 05f5ec6b39..395840d13e 100644
--- a/tensorflow/contrib/training/python/training/sampling_ops.py
+++ b/tensorflow/contrib/training/python/training/sampling_ops.py
@@ -34,8 +34,7 @@ from tensorflow.python.training import queue_runner
__all__ = ['rejection_sample',
- 'stratified_sample',
- 'stratified_sample_unknown_dist',]
+ 'stratified_sample',]
def rejection_sample(tensors, accept_prob_fn, batch_size, queue_threads=1,
@@ -124,9 +123,7 @@ def stratified_sample(tensors, labels, target_probs, batch_size,
This method discards examples. Internally, it creates one queue to amortize
the cost of disk reads, and one queue to hold the properly-proportioned
- batch. See `stratified_sample_unknown_dist` for a function that performs
- stratified sampling with one queue per class and doesn't require knowing the
- class data-distribution ahead of time.
+ batch.
Args:
tensors: List of tensors for data. All tensors are either one item or a
@@ -236,80 +233,6 @@ def stratified_sample(tensors, labels, target_probs, batch_size,
return batched[:-1], batched[-1]
-def stratified_sample_unknown_dist(tensors, labels, probs, batch_size,
- enqueue_many=False, queue_capacity=16,
- threads_per_queue=1, name=None):
- """Stochastically creates batches based on per-class probabilities.
-
- **NOTICE** This sampler can be significantly slower than `stratified_sample`
- due to each thread discarding all examples not in its assigned class.
-
- This uses a number of threads proportional to the number of classes. See
- `stratified_sample` for an implementation that discards fewer examples and
- uses a fixed number of threads. This function's only advantage over
- `stratified_sample` is that the class data-distribution doesn't need to be
- known ahead of time.
-
- Args:
- tensors: List of tensors for data. All tensors are either one item or a
- batch, according to enqueue_many.
- labels: Tensor for label of data. Label is a single integer or a batch,
- depending on enqueue_many. It is not a one-hot vector.
- probs: Target class probabilities. An object whose type has a registered
- Tensor conversion function.
- batch_size: Size of batch to be returned.
- enqueue_many: Bool. If true, interpret input tensors as having a batch
- dimension.
- queue_capacity: Capacity of each per-class queue.
- threads_per_queue: Number of threads for each per-class queue.
- name: Optional prefix for ops created by this function.
- Raises:
- ValueError: enqueue_many is True and labels doesn't have a batch
- dimension, or if enqueue_many is False and labels isn't a scalar.
- ValueError: enqueue_many is True, and batch dimension of data and labels
- don't match.
- ValueError: if probs don't sum to one.
- TFAssertion: if labels aren't integers in [0, num classes).
- Returns:
- (data_batch, label_batch), where data_batch is a list of tensors of the same
- length as `tensors`
-
- Example:
- # Get tensor for a single data and label example.
- data, label = data_provider.Get(['data', 'label'])
-
- # Get stratified batch according to per-class probabilities.
- init_probs = [1.0/NUM_CLASSES for _ in range(NUM_CLASSES)]
- [data_batch], labels = (
- tf.contrib.training.stratified_sample_unknown_dist(
- [data], label, init_probs, 16))
-
- # Run batch through network.
- ...
- """
- with ops.name_scope(name, 'stratified_sample_unknown_dist',
- tensors + [labels]):
- tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
- labels = ops.convert_to_tensor(labels)
- probs = ops.convert_to_tensor(probs, dtype=dtypes.float32)
- # Reduce the case of a single example to that of a batch of size 1.
- if not enqueue_many:
- tensor_list = [array_ops.expand_dims(tensor, 0) for tensor in tensor_list]
- labels = array_ops.expand_dims(labels, 0)
-
- # Validate that input is consistent.
- tensor_list, labels, [probs] = _verify_input(tensor_list, labels, [probs])
-
- # Make per-class queues.
- per_class_queues = _make_per_class_queues(
- tensor_list, labels, probs.get_shape().num_elements(), queue_capacity,
- threads_per_queue)
-
- # Use the per-class queues to generate stratified batches.
- return _get_batch_from_per_class_queues(
- per_class_queues, probs, batch_size)
-
-
def _estimate_data_distribution(labels, num_classes, smoothing_constant=10):
"""Estimate data distribution as labels are seen."""
# Variable to track running count of classes. Smooth by a nonzero value to
@@ -521,88 +444,3 @@ def _conditional_batch(tensors, keep_input, batch_size, num_threads=10):
out_tensor = [out_tensor]
return out_tensor
-
-
-def _make_per_class_queues(tensor_list, labels, num_classes, queue_capacity,
- threads_per_queue):
- """Creates per-class-queues based on data and labels."""
- # Create one queue per class.
- queues = []
- data_shapes = []
- data_dtypes = []
- for data_tensor in tensor_list:
- per_data_shape = data_tensor.get_shape().with_rank_at_least(1)[1:]
- per_data_shape.assert_is_fully_defined()
- data_shapes.append(per_data_shape)
- data_dtypes.append(data_tensor.dtype)
-
- for i in range(num_classes):
- q = data_flow_ops.FIFOQueue(capacity=queue_capacity,
- shapes=data_shapes, dtypes=data_dtypes,
- name='stratified_sample_class%d_queue' % i)
- logging_ops.scalar_summary(
- 'queue/%s/stratified_sample_class%d' % (q.name, i), q.size())
- queues.append(q)
-
- # Partition tensors according to labels. `partitions` is a list of lists, of
- # size num_classes X len(tensor_list). The number of tensors in partition `i`
- # should be the same for all tensors.
- all_partitions = [data_flow_ops.dynamic_partition(data, labels, num_classes)
- for data in tensor_list]
- partitions = [[cur_partition[i] for cur_partition in all_partitions] for i in
- range(num_classes)]
-
- # Enqueue each tensor on the per-class-queue.
- for i in range(num_classes):
- enqueue_op = queues[i].enqueue_many(partitions[i]),
- queue_runner.add_queue_runner(queue_runner.QueueRunner(
- queues[i], [enqueue_op] * threads_per_queue))
-
- return queues
-
-
-def _get_batch_from_per_class_queues(per_class_queues, probs, batch_size):
- """Generates batches according to per-class-probabilities."""
- num_classes = probs.get_shape().num_elements()
- # Number of examples per class is governed by a multinomial distribution.
- # Note: multinomial takes unnormalized log probabilities for its first
- # argument, of dimension [batch_size, num_classes].
- examples = random_ops.multinomial(
- array_ops.expand_dims(math_ops.log(probs), 0), batch_size)
-
- # Prepare the data and label batches.
- val_list = []
- label_list = []
- for i in range(num_classes):
- num_examples = math_ops.reduce_sum(
- math_ops.cast(math_ops.equal(examples, i), dtypes.int32))
- tensors = per_class_queues[i].dequeue_many(num_examples)
-
- # If you enqueue a list with a single tensor, only a single tensor is
- # returned. If you enqueue a list with multiple tensors, then a list is
- # returned. We want to handle both cases, so reduce the case of the single
- # tensor to the case of multiple tensors.
- if not isinstance(tensors, list):
- tensors = [tensors]
-
- val_list.append(tensors)
- label_list.append(array_ops.ones([num_examples], dtype=dtypes.int32) * i)
-
- # Create a list of tensor of values. val_list is of dimension
- # [num_classes x len(tensors)]. We want list_batch_vals to be of dimension
- # [len(tensors)].
- num_data = len(val_list[0])
- list_batch_vals = [array_ops.concat(
- 0, [val_list[i][j] for i in range(num_classes)]) for j in range(num_data)]
-
- # Create a tensor of labels.
- batch_labels = array_ops.concat(0, label_list)
- batch_labels.set_shape([batch_size])
-
- # Debug instrumentation.
- sample_tags = ['stratified_sample/%s/samples_class%i' % (batch_labels.name, i)
- for i in range(num_classes)]
- logging_ops.scalar_summary(sample_tags, math_ops.reduce_sum(
- array_ops.one_hot(batch_labels, num_classes), 0))
-
- return list_batch_vals, batch_labels
diff --git a/tensorflow/contrib/training/python/training/sampling_ops_test.py b/tensorflow/contrib/training/python/training/sampling_ops_test.py
index bbc0a284cd..40c9c0baf1 100644
--- a/tensorflow/contrib/training/python/training/sampling_ops_test.py
+++ b/tensorflow/contrib/training/python/training/sampling_ops_test.py
@@ -30,58 +30,49 @@ class StratifiedSampleTest(tf.test.TestCase):
val = [tf.zeros([1, 3]), tf.ones([1, 5])]
label = tf.constant([1], shape=[1]) # must have batch dimension
probs = [.2] * 5
- initial_p = [.1, .3, .1, .3, .2] # only used for stratified_sample
+ init_probs = [.1, .3, .1, .3, .2]
batch_size = 16
- # Curry the rejection sampler so we can easily run the same tests on both
- # stratified_sample and stratified_sample_unknown_dist.
- def curried_sampler(tensors, labels, probs, batch_size, enqueue_many=True):
- return tf.contrib.training.stratified_sample(
- tensors=tensors,
- labels=labels,
- target_probs=probs,
- batch_size=batch_size,
- init_probs=initial_p,
- enqueue_many=enqueue_many)
-
- samplers = [
- tf.contrib.training.stratified_sample_unknown_dist,
- curried_sampler,
- ]
-
- for sampler in samplers:
- logging.info('Now testing `%s`', sampler.__class__.__name__)
- # Label must have only batch dimension if enqueue_many is True.
- with self.assertRaises(ValueError):
- sampler(val, tf.zeros([]), probs, batch_size, enqueue_many=True)
- with self.assertRaises(ValueError):
- sampler(val, tf.zeros([1, 1]), probs, batch_size, enqueue_many=True)
+ # Label must have only batch dimension if enqueue_many is True.
+ with self.assertRaises(ValueError):
+ tf.contrib.training.stratified_sample(
+ val, tf.zeros([]), probs, batch_size, init_probs, enqueue_many=True)
+ with self.assertRaises(ValueError):
+ tf.contrib.training.stratified_sample(
+ val, tf.zeros([1, 1]), probs, batch_size, init_probs,
+ enqueue_many=True)
- # Label must not be one-hot.
- with self.assertRaises(ValueError):
- sampler(val, tf.constant([0, 1, 0, 0, 0]), probs, batch_size)
+ # Label must not be one-hot.
+ with self.assertRaises(ValueError):
+ tf.contrib.training.stratified_sample(
+ val, tf.constant([0, 1, 0, 0, 0]), probs, batch_size, init_probs)
- # Data must be list, not singleton tensor.
- with self.assertRaises(TypeError):
- sampler(tf.zeros([1, 3]), label, probs, batch_size)
+ # Data must be list, not singleton tensor.
+ with self.assertRaises(TypeError):
+ tf.contrib.training.stratified_sample(
+ tf.zeros([1, 3]), label, probs, batch_size, init_probs)
- # Data must have batch dimension if enqueue_many is True.
- with self.assertRaises(ValueError):
- sampler(val, tf.constant(1), probs, batch_size, enqueue_many=True)
+ # Data must have batch dimension if enqueue_many is True.
+ with self.assertRaises(ValueError):
+ tf.contrib.training.stratified_sample(
+ val, tf.constant(1), probs, batch_size, init_probs, enqueue_many=True)
- # Batch dimensions on data and labels should be equal.
- with self.assertRaises(ValueError):
- sampler([tf.zeros([2, 1])], label, probs, batch_size, enqueue_many=True)
+ # Batch dimensions on data and labels should be equal.
+ with self.assertRaises(ValueError):
+ tf.contrib.training.stratified_sample(
+ [tf.zeros([2, 1])], label, probs, batch_size, init_probs,
+ enqueue_many=True)
- # Probabilities must be numpy array, python list, or tensor.
- with self.assertRaises(ValueError):
- sampler(val, label, 1, batch_size)
+ # Probabilities must be numpy array, python list, or tensor.
+ with self.assertRaises(ValueError):
+ tf.contrib.training.stratified_sample(
+ val, label, 1, batch_size, init_probs)
- # Probabilities shape must be fully defined.
- with self.assertRaises(ValueError):
- sampler(
- val, label, tf.placeholder(
- tf.float32, shape=[None]), batch_size)
+ # Probabilities shape must be fully defined.
+ with self.assertRaises(ValueError):
+ tf.contrib.training.stratified_sample(
+ val, label, tf.placeholder(
+ tf.float32, shape=[None]), batch_size, init_probs)
# In the rejection sampling case, make sure that probability lengths are
# the same.
@@ -95,11 +86,6 @@ class StratifiedSampleTest(tf.test.TestCase):
tf.contrib.training.stratified_sample(
val, label, [.2, .4, .4], batch_size, init_probs=[0, .5, .5])
- # Probabilities must be 1D.
- with self.assertRaises(ValueError):
- tf.contrib.training.stratified_sample_unknown_dist(
- val, label, np.array([[.25, .25], [.25, .25]]), batch_size)
-
def testRuntimeAssertionFailures(self):
valid_probs = [.2] * 5
valid_labels = [1, 2, 3]
@@ -138,26 +124,6 @@ class StratifiedSampleTest(tf.test.TestCase):
feed_dict={label_ph: valid_labels,
probs_ph: illegal_prob})
- def batchingBehaviorHelper(self, sampler):
- batch_size = 20
- input_batch_size = 11
- val_input_batch = [tf.zeros([input_batch_size, 2, 3, 4])]
- lbl_input_batch = tf.cond(
- tf.greater(.5, tf.random_uniform([])),
- lambda: tf.ones([input_batch_size], dtype=tf.int32) * 1,
- lambda: tf.ones([input_batch_size], dtype=tf.int32) * 3)
- probs = np.array([0, .2, 0, .8, 0])
- data_batch, labels = sampler(
- val_input_batch, lbl_input_batch, probs, batch_size, enqueue_many=True)
- with self.test_session() as sess:
- coord = tf.train.Coordinator()
- threads = tf.train.start_queue_runners(coord=coord)
-
- sess.run([data_batch, labels])
-
- coord.request_stop()
- coord.join(threads)
-
def testCanBeCalledMultipleTimes(self):
batch_size = 20
val_input_batch = [tf.zeros([2, 3, 4])]
@@ -167,10 +133,6 @@ class StratifiedSampleTest(tf.test.TestCase):
val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
batches += tf.contrib.training.stratified_sample(
val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs)
- batches += tf.contrib.training.stratified_sample_unknown_dist(
- val_input_batch, lbl_input_batch, probs, batch_size)
- batches += tf.contrib.training.stratified_sample_unknown_dist(
- val_input_batch, lbl_input_batch, probs, batch_size)
summary_op = tf.merge_summary(tf.get_collection(tf.GraphKeys.SUMMARIES))
with self.test_session() as sess:
@@ -182,58 +144,23 @@ class StratifiedSampleTest(tf.test.TestCase):
coord.request_stop()
coord.join(threads)
- def testBatchingBehavior(self):
- self.batchingBehaviorHelper(
- tf.contrib.training.stratified_sample_unknown_dist)
-
def testRejectionBatchingBehavior(self):
- initial_p = [0, .3, 0, .7, 0]
-
- def curried_sampler(val, lbls, probs, batch, enqueue_many=True):
- return tf.contrib.training.stratified_sample(
- val,
- lbls,
- probs,
- batch,
- init_probs=initial_p,
- enqueue_many=enqueue_many)
-
- self.batchingBehaviorHelper(curried_sampler)
-
- def testProbabilitiesCanBeChanged(self):
- # Set up graph.
- tf.set_random_seed(1234)
- lbl1 = 0
- lbl2 = 3
- # This cond allows the necessary class queues to be populated.
- label = tf.cond(
- tf.greater(.5, tf.random_uniform([])), lambda: tf.constant(lbl1),
- lambda: tf.constant(lbl2))
- val = [np.array([1, 4]) * label]
- probs = tf.placeholder(tf.float32, shape=[5])
- batch_size = 2
-
- data_batch, labels = tf.contrib.training.stratified_sample_unknown_dist(
- val, label, probs, batch_size)
-
+ batch_size = 20
+ input_batch_size = 11
+ val_input_batch = [tf.zeros([input_batch_size, 2, 3, 4])]
+ lbl_input_batch = tf.cond(
+ tf.greater(.5, tf.random_uniform([])),
+ lambda: tf.ones([input_batch_size], dtype=tf.int32) * 1,
+ lambda: tf.ones([input_batch_size], dtype=tf.int32) * 3)
+ probs = np.array([0, .2, 0, .8, 0])
+ data_batch, labels = tf.contrib.training.stratified_sample(
+ val_input_batch, lbl_input_batch, probs, batch_size,
+ init_probs=[0, .3, 0, .7, 0], enqueue_many=True)
with self.test_session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
- for _ in range(5):
- [data], lbls = sess.run([data_batch, labels],
- feed_dict={probs: [1, 0, 0, 0, 0]})
- for data_example in data:
- self.assertListEqual([0, 0], list(data_example))
- self.assertListEqual([0, 0], list(lbls))
-
- # Now change distribution and expect different output.
- for _ in range(5):
- [data], lbls = sess.run([data_batch, labels],
- feed_dict={probs: [0, 0, 0, 1, 0]})
- for data_example in data:
- self.assertListEqual([3, 12], list(data_example))
- self.assertListEqual([3, 3], list(lbls))
+ sess.run([data_batch, labels])
coord.request_stop()
coord.join(threads)
@@ -263,13 +190,14 @@ class StratifiedSampleTest(tf.test.TestCase):
feed_dict={vals_ph: vals,
labels_ph: labels})
- def dataListHelper(self, sampler):
+ def testRejectionDataListInput(self):
batch_size = 20
val_input_batch = [tf.zeros([2, 3, 4]), tf.ones([2, 4]), tf.ones(2) * 3]
lbl_input_batch = tf.ones([], dtype=tf.int32)
probs = np.array([0, 1, 0, 0, 0])
- val_list, lbls = sampler(val_input_batch, lbl_input_batch, probs,
- batch_size)
+ val_list, lbls = tf.contrib.training.stratified_sample(
+ val_input_batch, lbl_input_batch, probs, batch_size,
+ init_probs=[0, 1, 0, 0, 0])
# Check output shapes.
self.assertTrue(isinstance(val_list, list))
@@ -288,24 +216,6 @@ class StratifiedSampleTest(tf.test.TestCase):
# Check output shapes.
self.assertEqual(len(out), len(val_input_batch) + 1)
- def testDataListInput(self):
- self.dataListHelper(
- tf.contrib.training.stratified_sample_unknown_dist)
-
- def testRejectionDataListInput(self):
- initial_p = [0, 1, 0, 0, 0]
-
- def curried_sampler(val, lbls, probs, batch, enqueue_many=False):
- return tf.contrib.training.stratified_sample(
- val,
- lbls,
- probs,
- batch,
- init_probs=initial_p,
- enqueue_many=enqueue_many)
-
- self.dataListHelper(curried_sampler)
-
def normalBehaviorHelper(self, sampler):
# Set up graph.
tf.set_random_seed(1234)
@@ -357,10 +267,6 @@ class StratifiedSampleTest(tf.test.TestCase):
# an implementation detail, which would cause the random behavior to differ.
self.assertNear(actual_lbl, expected_label, 3 * lbl_std_dev_of_mean)
- def testNormalBehavior(self):
- self.normalBehaviorHelper(
- tf.contrib.training.stratified_sample_unknown_dist)
-
def testRejectionNormalBehavior(self):
initial_p = [.7, 0, 0, .3, 0]