diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-10-17 08:51:36 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-17 10:08:42 -0700 |
commit | 83f2fe11b30c4111084f1d99f23396e8d01b45b7 (patch) | |
tree | 162a1511229b78f7f93dd874c3cf018ac133bfad /tensorflow | |
parent | 090e0743904a45bc3169b006ad90c7a4720f0998 (diff) |
Remove `stratified_sample_unknown_dist`, since `stratified_sample` supports an unknown data distribution.
Change: 136361350
Diffstat (limited to 'tensorflow')
3 files changed, 58 insertions, 317 deletions
diff --git a/tensorflow/contrib/training/__init__.py b/tensorflow/contrib/training/__init__.py index 1290854260..c9564fc316 100644 --- a/tensorflow/contrib/training/__init__.py +++ b/tensorflow/contrib/training/__init__.py @@ -37,18 +37,15 @@ providing the desired rate for each example. If you wish to specify relative rates, rather than absolute ones, use ['weighted_resample'](#weighted_resample) (which also returns the actual resampling rate used for each output example). -Use ['stratified_sample'](#stratified_sample) or -['stratified_sample_unknown_dist'](#stratified_sample_unknown_dist) to -resample without replacement from the data to achieve a desired mix of -class proportions that the Tensorflow graph sees. For instance, if you -have a binary classification dataset that is 99.9% class 1, a common -approach is to resample from the data so that the data is more -balanced. +Use ['stratified_sample'](#stratified_sample) to resample without replacement +from the data to achieve a desired mix of class proportions that the Tensorflow +graph sees. For instance, if you have a binary classification dataset that is +99.9% class 1, a common approach is to resample from the data so that the data +is more balanced. @@rejection_sample @@resample_at_rate @@stratified_sample -@@stratified_sample_unknown_dist @@weighted_resample ## Bucketing diff --git a/tensorflow/contrib/training/python/training/sampling_ops.py b/tensorflow/contrib/training/python/training/sampling_ops.py index 05f5ec6b39..395840d13e 100644 --- a/tensorflow/contrib/training/python/training/sampling_ops.py +++ b/tensorflow/contrib/training/python/training/sampling_ops.py @@ -34,8 +34,7 @@ from tensorflow.python.training import queue_runner __all__ = ['rejection_sample', - 'stratified_sample', - 'stratified_sample_unknown_dist',] + 'stratified_sample',] def rejection_sample(tensors, accept_prob_fn, batch_size, queue_threads=1, @@ -124,9 +123,7 @@ def stratified_sample(tensors, labels, target_probs, batch_size, This method discards examples. Internally, it creates one queue to amortize the cost of disk reads, and one queue to hold the properly-proportioned - batch. See `stratified_sample_unknown_dist` for a function that performs - stratified sampling with one queue per class and doesn't require knowing the - class data-distribution ahead of time. + batch. Args: tensors: List of tensors for data. All tensors are either one item or a @@ -236,80 +233,6 @@ def stratified_sample(tensors, labels, target_probs, batch_size, return batched[:-1], batched[-1] -def stratified_sample_unknown_dist(tensors, labels, probs, batch_size, - enqueue_many=False, queue_capacity=16, - threads_per_queue=1, name=None): - """Stochastically creates batches based on per-class probabilities. - - **NOTICE** This sampler can be significantly slower than `stratified_sample` - due to each thread discarding all examples not in its assigned class. - - This uses a number of threads proportional to the number of classes. See - `stratified_sample` for an implementation that discards fewer examples and - uses a fixed number of threads. This function's only advantage over - `stratified_sample` is that the class data-distribution doesn't need to be - known ahead of time. - - Args: - tensors: List of tensors for data. All tensors are either one item or a - batch, according to enqueue_many. - labels: Tensor for label of data. Label is a single integer or a batch, - depending on enqueue_many. It is not a one-hot vector. - probs: Target class probabilities. An object whose type has a registered - Tensor conversion function. - batch_size: Size of batch to be returned. - enqueue_many: Bool. If true, interpret input tensors as having a batch - dimension. - queue_capacity: Capacity of each per-class queue. - threads_per_queue: Number of threads for each per-class queue. - name: Optional prefix for ops created by this function. - Raises: - ValueError: enqueue_many is True and labels doesn't have a batch - dimension, or if enqueue_many is False and labels isn't a scalar. - ValueError: enqueue_many is True, and batch dimension of data and labels - don't match. - ValueError: if probs don't sum to one. - TFAssertion: if labels aren't integers in [0, num classes). - Returns: - (data_batch, label_batch), where data_batch is a list of tensors of the same - length as `tensors` - - Example: - # Get tensor for a single data and label example. - data, label = data_provider.Get(['data', 'label']) - - # Get stratified batch according to per-class probabilities. - init_probs = [1.0/NUM_CLASSES for _ in range(NUM_CLASSES)] - [data_batch], labels = ( - tf.contrib.training.stratified_sample_unknown_dist( - [data], label, init_probs, 16)) - - # Run batch through network. - ... - """ - with ops.name_scope(name, 'stratified_sample_unknown_dist', - tensors + [labels]): - tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors) - labels = ops.convert_to_tensor(labels) - probs = ops.convert_to_tensor(probs, dtype=dtypes.float32) - # Reduce the case of a single example to that of a batch of size 1. - if not enqueue_many: - tensor_list = [array_ops.expand_dims(tensor, 0) for tensor in tensor_list] - labels = array_ops.expand_dims(labels, 0) - - # Validate that input is consistent. - tensor_list, labels, [probs] = _verify_input(tensor_list, labels, [probs]) - - # Make per-class queues. - per_class_queues = _make_per_class_queues( - tensor_list, labels, probs.get_shape().num_elements(), queue_capacity, - threads_per_queue) - - # Use the per-class queues to generate stratified batches. - return _get_batch_from_per_class_queues( - per_class_queues, probs, batch_size) - - def _estimate_data_distribution(labels, num_classes, smoothing_constant=10): """Estimate data distribution as labels are seen.""" # Variable to track running count of classes. Smooth by a nonzero value to @@ -521,88 +444,3 @@ def _conditional_batch(tensors, keep_input, batch_size, num_threads=10): out_tensor = [out_tensor] return out_tensor - - -def _make_per_class_queues(tensor_list, labels, num_classes, queue_capacity, - threads_per_queue): - """Creates per-class-queues based on data and labels.""" - # Create one queue per class. - queues = [] - data_shapes = [] - data_dtypes = [] - for data_tensor in tensor_list: - per_data_shape = data_tensor.get_shape().with_rank_at_least(1)[1:] - per_data_shape.assert_is_fully_defined() - data_shapes.append(per_data_shape) - data_dtypes.append(data_tensor.dtype) - - for i in range(num_classes): - q = data_flow_ops.FIFOQueue(capacity=queue_capacity, - shapes=data_shapes, dtypes=data_dtypes, - name='stratified_sample_class%d_queue' % i) - logging_ops.scalar_summary( - 'queue/%s/stratified_sample_class%d' % (q.name, i), q.size()) - queues.append(q) - - # Partition tensors according to labels. `partitions` is a list of lists, of - # size num_classes X len(tensor_list). The number of tensors in partition `i` - # should be the same for all tensors. - all_partitions = [data_flow_ops.dynamic_partition(data, labels, num_classes) - for data in tensor_list] - partitions = [[cur_partition[i] for cur_partition in all_partitions] for i in - range(num_classes)] - - # Enqueue each tensor on the per-class-queue. - for i in range(num_classes): - enqueue_op = queues[i].enqueue_many(partitions[i]), - queue_runner.add_queue_runner(queue_runner.QueueRunner( - queues[i], [enqueue_op] * threads_per_queue)) - - return queues - - -def _get_batch_from_per_class_queues(per_class_queues, probs, batch_size): - """Generates batches according to per-class-probabilities.""" - num_classes = probs.get_shape().num_elements() - # Number of examples per class is governed by a multinomial distribution. - # Note: multinomial takes unnormalized log probabilities for its first - # argument, of dimension [batch_size, num_classes]. - examples = random_ops.multinomial( - array_ops.expand_dims(math_ops.log(probs), 0), batch_size) - - # Prepare the data and label batches. - val_list = [] - label_list = [] - for i in range(num_classes): - num_examples = math_ops.reduce_sum( - math_ops.cast(math_ops.equal(examples, i), dtypes.int32)) - tensors = per_class_queues[i].dequeue_many(num_examples) - - # If you enqueue a list with a single tensor, only a single tensor is - # returned. If you enqueue a list with multiple tensors, then a list is - # returned. We want to handle both cases, so reduce the case of the single - # tensor to the case of multiple tensors. - if not isinstance(tensors, list): - tensors = [tensors] - - val_list.append(tensors) - label_list.append(array_ops.ones([num_examples], dtype=dtypes.int32) * i) - - # Create a list of tensor of values. val_list is of dimension - # [num_classes x len(tensors)]. We want list_batch_vals to be of dimension - # [len(tensors)]. - num_data = len(val_list[0]) - list_batch_vals = [array_ops.concat( - 0, [val_list[i][j] for i in range(num_classes)]) for j in range(num_data)] - - # Create a tensor of labels. - batch_labels = array_ops.concat(0, label_list) - batch_labels.set_shape([batch_size]) - - # Debug instrumentation. - sample_tags = ['stratified_sample/%s/samples_class%i' % (batch_labels.name, i) - for i in range(num_classes)] - logging_ops.scalar_summary(sample_tags, math_ops.reduce_sum( - array_ops.one_hot(batch_labels, num_classes), 0)) - - return list_batch_vals, batch_labels diff --git a/tensorflow/contrib/training/python/training/sampling_ops_test.py b/tensorflow/contrib/training/python/training/sampling_ops_test.py index bbc0a284cd..40c9c0baf1 100644 --- a/tensorflow/contrib/training/python/training/sampling_ops_test.py +++ b/tensorflow/contrib/training/python/training/sampling_ops_test.py @@ -30,58 +30,49 @@ class StratifiedSampleTest(tf.test.TestCase): val = [tf.zeros([1, 3]), tf.ones([1, 5])] label = tf.constant([1], shape=[1]) # must have batch dimension probs = [.2] * 5 - initial_p = [.1, .3, .1, .3, .2] # only used for stratified_sample + init_probs = [.1, .3, .1, .3, .2] batch_size = 16 - # Curry the rejection sampler so we can easily run the same tests on both - # stratified_sample and stratified_sample_unknown_dist. - def curried_sampler(tensors, labels, probs, batch_size, enqueue_many=True): - return tf.contrib.training.stratified_sample( - tensors=tensors, - labels=labels, - target_probs=probs, - batch_size=batch_size, - init_probs=initial_p, - enqueue_many=enqueue_many) - - samplers = [ - tf.contrib.training.stratified_sample_unknown_dist, - curried_sampler, - ] - - for sampler in samplers: - logging.info('Now testing `%s`', sampler.__class__.__name__) - # Label must have only batch dimension if enqueue_many is True. - with self.assertRaises(ValueError): - sampler(val, tf.zeros([]), probs, batch_size, enqueue_many=True) - with self.assertRaises(ValueError): - sampler(val, tf.zeros([1, 1]), probs, batch_size, enqueue_many=True) + # Label must have only batch dimension if enqueue_many is True. + with self.assertRaises(ValueError): + tf.contrib.training.stratified_sample( + val, tf.zeros([]), probs, batch_size, init_probs, enqueue_many=True) + with self.assertRaises(ValueError): + tf.contrib.training.stratified_sample( + val, tf.zeros([1, 1]), probs, batch_size, init_probs, + enqueue_many=True) - # Label must not be one-hot. - with self.assertRaises(ValueError): - sampler(val, tf.constant([0, 1, 0, 0, 0]), probs, batch_size) + # Label must not be one-hot. + with self.assertRaises(ValueError): + tf.contrib.training.stratified_sample( + val, tf.constant([0, 1, 0, 0, 0]), probs, batch_size, init_probs) - # Data must be list, not singleton tensor. - with self.assertRaises(TypeError): - sampler(tf.zeros([1, 3]), label, probs, batch_size) + # Data must be list, not singleton tensor. + with self.assertRaises(TypeError): + tf.contrib.training.stratified_sample( + tf.zeros([1, 3]), label, probs, batch_size, init_probs) - # Data must have batch dimension if enqueue_many is True. - with self.assertRaises(ValueError): - sampler(val, tf.constant(1), probs, batch_size, enqueue_many=True) + # Data must have batch dimension if enqueue_many is True. + with self.assertRaises(ValueError): + tf.contrib.training.stratified_sample( + val, tf.constant(1), probs, batch_size, init_probs, enqueue_many=True) - # Batch dimensions on data and labels should be equal. - with self.assertRaises(ValueError): - sampler([tf.zeros([2, 1])], label, probs, batch_size, enqueue_many=True) + # Batch dimensions on data and labels should be equal. + with self.assertRaises(ValueError): + tf.contrib.training.stratified_sample( + [tf.zeros([2, 1])], label, probs, batch_size, init_probs, + enqueue_many=True) - # Probabilities must be numpy array, python list, or tensor. - with self.assertRaises(ValueError): - sampler(val, label, 1, batch_size) + # Probabilities must be numpy array, python list, or tensor. + with self.assertRaises(ValueError): + tf.contrib.training.stratified_sample( + val, label, 1, batch_size, init_probs) - # Probabilities shape must be fully defined. - with self.assertRaises(ValueError): - sampler( - val, label, tf.placeholder( - tf.float32, shape=[None]), batch_size) + # Probabilities shape must be fully defined. + with self.assertRaises(ValueError): + tf.contrib.training.stratified_sample( + val, label, tf.placeholder( + tf.float32, shape=[None]), batch_size, init_probs) # In the rejection sampling case, make sure that probability lengths are # the same. @@ -95,11 +86,6 @@ class StratifiedSampleTest(tf.test.TestCase): tf.contrib.training.stratified_sample( val, label, [.2, .4, .4], batch_size, init_probs=[0, .5, .5]) - # Probabilities must be 1D. - with self.assertRaises(ValueError): - tf.contrib.training.stratified_sample_unknown_dist( - val, label, np.array([[.25, .25], [.25, .25]]), batch_size) - def testRuntimeAssertionFailures(self): valid_probs = [.2] * 5 valid_labels = [1, 2, 3] @@ -138,26 +124,6 @@ class StratifiedSampleTest(tf.test.TestCase): feed_dict={label_ph: valid_labels, probs_ph: illegal_prob}) - def batchingBehaviorHelper(self, sampler): - batch_size = 20 - input_batch_size = 11 - val_input_batch = [tf.zeros([input_batch_size, 2, 3, 4])] - lbl_input_batch = tf.cond( - tf.greater(.5, tf.random_uniform([])), - lambda: tf.ones([input_batch_size], dtype=tf.int32) * 1, - lambda: tf.ones([input_batch_size], dtype=tf.int32) * 3) - probs = np.array([0, .2, 0, .8, 0]) - data_batch, labels = sampler( - val_input_batch, lbl_input_batch, probs, batch_size, enqueue_many=True) - with self.test_session() as sess: - coord = tf.train.Coordinator() - threads = tf.train.start_queue_runners(coord=coord) - - sess.run([data_batch, labels]) - - coord.request_stop() - coord.join(threads) - def testCanBeCalledMultipleTimes(self): batch_size = 20 val_input_batch = [tf.zeros([2, 3, 4])] @@ -167,10 +133,6 @@ class StratifiedSampleTest(tf.test.TestCase): val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs) batches += tf.contrib.training.stratified_sample( val_input_batch, lbl_input_batch, probs, batch_size, init_probs=probs) - batches += tf.contrib.training.stratified_sample_unknown_dist( - val_input_batch, lbl_input_batch, probs, batch_size) - batches += tf.contrib.training.stratified_sample_unknown_dist( - val_input_batch, lbl_input_batch, probs, batch_size) summary_op = tf.merge_summary(tf.get_collection(tf.GraphKeys.SUMMARIES)) with self.test_session() as sess: @@ -182,58 +144,23 @@ class StratifiedSampleTest(tf.test.TestCase): coord.request_stop() coord.join(threads) - def testBatchingBehavior(self): - self.batchingBehaviorHelper( - tf.contrib.training.stratified_sample_unknown_dist) - def testRejectionBatchingBehavior(self): - initial_p = [0, .3, 0, .7, 0] - - def curried_sampler(val, lbls, probs, batch, enqueue_many=True): - return tf.contrib.training.stratified_sample( - val, - lbls, - probs, - batch, - init_probs=initial_p, - enqueue_many=enqueue_many) - - self.batchingBehaviorHelper(curried_sampler) - - def testProbabilitiesCanBeChanged(self): - # Set up graph. - tf.set_random_seed(1234) - lbl1 = 0 - lbl2 = 3 - # This cond allows the necessary class queues to be populated. - label = tf.cond( - tf.greater(.5, tf.random_uniform([])), lambda: tf.constant(lbl1), - lambda: tf.constant(lbl2)) - val = [np.array([1, 4]) * label] - probs = tf.placeholder(tf.float32, shape=[5]) - batch_size = 2 - - data_batch, labels = tf.contrib.training.stratified_sample_unknown_dist( - val, label, probs, batch_size) - + batch_size = 20 + input_batch_size = 11 + val_input_batch = [tf.zeros([input_batch_size, 2, 3, 4])] + lbl_input_batch = tf.cond( + tf.greater(.5, tf.random_uniform([])), + lambda: tf.ones([input_batch_size], dtype=tf.int32) * 1, + lambda: tf.ones([input_batch_size], dtype=tf.int32) * 3) + probs = np.array([0, .2, 0, .8, 0]) + data_batch, labels = tf.contrib.training.stratified_sample( + val_input_batch, lbl_input_batch, probs, batch_size, + init_probs=[0, .3, 0, .7, 0], enqueue_many=True) with self.test_session() as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) - for _ in range(5): - [data], lbls = sess.run([data_batch, labels], - feed_dict={probs: [1, 0, 0, 0, 0]}) - for data_example in data: - self.assertListEqual([0, 0], list(data_example)) - self.assertListEqual([0, 0], list(lbls)) - - # Now change distribution and expect different output. - for _ in range(5): - [data], lbls = sess.run([data_batch, labels], - feed_dict={probs: [0, 0, 0, 1, 0]}) - for data_example in data: - self.assertListEqual([3, 12], list(data_example)) - self.assertListEqual([3, 3], list(lbls)) + sess.run([data_batch, labels]) coord.request_stop() coord.join(threads) @@ -263,13 +190,14 @@ class StratifiedSampleTest(tf.test.TestCase): feed_dict={vals_ph: vals, labels_ph: labels}) - def dataListHelper(self, sampler): + def testRejectionDataListInput(self): batch_size = 20 val_input_batch = [tf.zeros([2, 3, 4]), tf.ones([2, 4]), tf.ones(2) * 3] lbl_input_batch = tf.ones([], dtype=tf.int32) probs = np.array([0, 1, 0, 0, 0]) - val_list, lbls = sampler(val_input_batch, lbl_input_batch, probs, - batch_size) + val_list, lbls = tf.contrib.training.stratified_sample( + val_input_batch, lbl_input_batch, probs, batch_size, + init_probs=[0, 1, 0, 0, 0]) # Check output shapes. self.assertTrue(isinstance(val_list, list)) @@ -288,24 +216,6 @@ class StratifiedSampleTest(tf.test.TestCase): # Check output shapes. self.assertEqual(len(out), len(val_input_batch) + 1) - def testDataListInput(self): - self.dataListHelper( - tf.contrib.training.stratified_sample_unknown_dist) - - def testRejectionDataListInput(self): - initial_p = [0, 1, 0, 0, 0] - - def curried_sampler(val, lbls, probs, batch, enqueue_many=False): - return tf.contrib.training.stratified_sample( - val, - lbls, - probs, - batch, - init_probs=initial_p, - enqueue_many=enqueue_many) - - self.dataListHelper(curried_sampler) - def normalBehaviorHelper(self, sampler): # Set up graph. tf.set_random_seed(1234) @@ -357,10 +267,6 @@ class StratifiedSampleTest(tf.test.TestCase): # an implementation detail, which would cause the random behavior to differ. self.assertNear(actual_lbl, expected_label, 3 * lbl_std_dev_of_mean) - def testNormalBehavior(self): - self.normalBehaviorHelper( - tf.contrib.training.stratified_sample_unknown_dist) - def testRejectionNormalBehavior(self): initial_p = [.7, 0, 0, .3, 0] |