diff options
author | 2016-10-11 17:34:13 -0800 | |
---|---|---|
committer | 2016-10-11 18:50:14 -0700 | |
commit | 2c183364e6994ef00d3cff930cb383c4e9443f25 (patch) | |
tree | 95abd17465954d768d834f2afeea666f3f208f9f /tensorflow | |
parent | 40fc65b0f5791f40132fbd173bfde7521f14fed5 (diff) |
Check in code to perform rejection sampling.
Change: 135869416
Diffstat (limited to 'tensorflow')
3 files changed, 181 insertions, 19 deletions
diff --git a/tensorflow/contrib/training/__init__.py b/tensorflow/contrib/training/__init__.py index fc0e324bcf..1290854260 100644 --- a/tensorflow/contrib/training/__init__.py +++ b/tensorflow/contrib/training/__init__.py @@ -30,10 +30,12 @@ like to store state in the forward direction across segments of an example. ## Online data resampling To resample data with replacement on a per-example basis, use -['resample_at_rate'](#resample_at_rate), providing the desired rate -for each example. If you wish to specify relative rates, rather than -absolute ones, use ['weighted_resample'](#weighted_resample) (which -also returns the actual resampling rate used for each output example). +['rejection_sample'](#rejection_sample) or +['resample_at_rate'](#resample_at_rate). For `rejection_sample`, provide +a boolean Tensor describing whether to accept or reject. For `resample_at_rate`, +providing the desired rate for each example. If you wish to specify relative +rates, rather than absolute ones, use ['weighted_resample'](#weighted_resample) +(which also returns the actual resampling rate used for each output example). Use ['stratified_sample'](#stratified_sample) or ['stratified_sample_unknown_dist'](#stratified_sample_unknown_dist) to @@ -43,6 +45,7 @@ have a binary classification dataset that is 99.9% class 1, a common approach is to resample from the data so that the data is more balanced. +@@rejection_sample @@resample_at_rate @@stratified_sample @@stratified_sample_unknown_dist diff --git a/tensorflow/contrib/training/python/training/sampling_ops.py b/tensorflow/contrib/training/python/training/sampling_ops.py index c703e22e24..05f5ec6b39 100644 --- a/tensorflow/contrib/training/python/training/sampling_ops.py +++ b/tensorflow/contrib/training/python/training/sampling_ops.py @@ -27,14 +27,96 @@ from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import input as input_ops from tensorflow.python.training import queue_runner -__all__ = ['stratified_sample', + +__all__ = ['rejection_sample', + 'stratified_sample', 'stratified_sample_unknown_dist',] +def rejection_sample(tensors, accept_prob_fn, batch_size, queue_threads=1, + enqueue_many=False, prebatch_capacity=16, + prebatch_threads=1, runtime_checks=False, name=None): + """Stochastically creates batches by rejection sampling. + + Each list of non-batched tensors is evaluated by `accept_prob_fn`, to produce + a scalar tensor between 0 and 1. This tensor corresponds to the probability of + being accepted. When `batch_size` tensor groups have been accepted, the batch + queue will return a mini-batch. + + Args: + tensors: List of tensors for data. All tensors are either one item or a + batch, according to enqueue_many. + accept_prob_fn: A python lambda that takes a non-batch tensor from each + item in `tensors`, and produces a scalar tensor. + batch_size: Size of batch to be returned. + queue_threads: The number of threads for the queue that will hold the final + batch. + enqueue_many: Bool. If true, interpret input tensors as having a batch + dimension. + prebatch_capacity: Capacity for the large queue that is used to convert + batched tensors to single examples. + prebatch_threads: Number of threads for the large queue that is used to + convert batched tensors to single examples. + runtime_checks: Bool. If true, insert runtime checks on the output of + `accept_prob_fn`. Using `True` might have a performance impact. + name: Optional prefix for ops created by this function. + Raises: + ValueError: enqueue_many is True and labels doesn't have a batch + dimension, or if enqueue_many is False and labels isn't a scalar. + ValueError: enqueue_many is True, and batch dimension on data and labels + don't match. + ValueError: if a zero initial probability class has a nonzero target + probability. + Returns: + A list of tensors of the same length as `tensors`, with batch dimension + `batch_size`. + + Example: + # Get tensor for a single data and label example. + data, label = data_provider.Get(['data', 'label']) + + # Get stratified batch according to data tensor. + accept_prob_fn = lambda x: (tf.tanh(x[0]) + 1) / 2 + data_batch = tf.contrib.training.rejection_sample( + [data, label], accept_prob_fn, 16) + + # Run batch through network. + ... + """ + with variable_scope.variable_scope(name, 'rejection_sample', tensors): + tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors) + # Reduce the case of a batched example to that of a batch of a single + # example by taking a batch of size one. + if enqueue_many: + # Validate that batch dimension of the input is consistent. + tensor_list = _verify_data_inputs(tensor_list) + + # Make a single queue to hold input examples. Reshape output so examples + # don't have singleton batch dimension. + batched = input_ops.batch(tensor_list, + batch_size=1, + num_threads=prebatch_threads, + capacity=prebatch_capacity, + enqueue_many=True) + tensor_list = [array_ops.squeeze(x, [0]) for x in batched] + + # Set up a queue containing batches that have the distribution. + cur_prob = accept_prob_fn(tensor_list) + if runtime_checks: + cur_prob = array_ops.identity(control_flow_ops.with_dependencies( + [check_ops.assert_less_equal(0.0, cur_prob), + check_ops.assert_less_equal(cur_prob, 1.0)], + cur_prob), name='prob_with_checks') + keep_input = random_ops.random_uniform([]) < cur_prob + return _conditional_batch( + tensor_list, keep_input, batch_size, num_threads=queue_threads) + + def stratified_sample(tensors, labels, target_probs, batch_size, init_probs=None, enqueue_many=False, queue_capacity=16, threads_per_queue=1, name=None): @@ -145,8 +227,12 @@ def stratified_sample(tensors, labels, target_probs, batch_size, # Set up second queue containing batches that have the desired class # proportions. cur_prob = array_ops.gather(accept_probs, label) + keep_input = random_ops.random_uniform([]) < cur_prob batched = _conditional_batch( - val_list + [label], cur_prob, batch_size, threads_per_queue) + val_list + [label], + keep_input, + batch_size, + num_threads=threads_per_queue) return batched[:-1], batched[-1] @@ -260,6 +346,18 @@ def _estimate_data_distribution(labels, num_classes, smoothing_constant=10): return math_ops.cast(init_prob_estimate, dtypes.float32) +def _verify_data_inputs(tensor_list): + """Verify that batched data inputs are well-formed.""" + for tensor in tensor_list: + # Data tensor should have a batch dimension. + tensor_shape = tensor.get_shape().with_rank_at_least(1) + + # Data batch dimensions must be compatible. + tensor_shape[0].assert_is_compatible_with(tensor_list[0].get_shape()[0]) + + return tensor_list + + def _verify_input(tensor_list, labels, probs_list): """Verify that batched inputs are well-formed.""" checked_probs_list = [] @@ -374,16 +472,16 @@ def _calculate_acceptance_probabilities(init_probs, target_probs): return ratio_l / max_ratio -def _conditional_batch(tensors, accept_prob, batch_size, queue_threads=10): +def _conditional_batch(tensors, keep_input, batch_size, num_threads=10): """Conditionally enqueue tensors based on accept_prob. Specifically, enqueue the element if accept_prob > rand_unif([0, 1]). Args: tensors: List of tensors to enqueue. - accept_prob: Acceptance probability per example. + keep_input: Bool. Whether to enqueue or not. batch_size: Size of batch. - queue_threads: Number of threads enqueuing in the final queue. + num_threads: Number of enqueueing threads. Returns: List of batched tensors. @@ -391,7 +489,7 @@ def _conditional_batch(tensors, accept_prob, batch_size, queue_threads=10): Raises: ValueError: `accept_prob` isn't 0D. """ - accept_prob.get_shape().assert_has_rank(0) + keep_input.get_shape().assert_has_rank(0) # Determine shapes and types of to-be-enqueued-tensors. shapes_list = [] dtypes_list = [] @@ -409,13 +507,12 @@ def _conditional_batch(tensors, accept_prob, batch_size, queue_threads=10): # Conditionally enqueue. # Reshape enqueue op to match no_op's shape. - eq_tf = math_ops.less(random_ops.random_uniform([]), accept_prob) conditional_enqueue = control_flow_ops.cond( - eq_tf, + keep_input, lambda: final_q.enqueue(tensors), control_flow_ops.no_op) queue_runner.add_queue_runner(queue_runner.QueueRunner( - final_q, [conditional_enqueue] * queue_threads)) + final_q, [conditional_enqueue] * num_threads)) out_tensor = final_q.dequeue_many(batch_size) # Queues return a single tensor if the list of enqued tensors is one. Since we diff --git a/tensorflow/contrib/training/python/training/sampling_ops_test.py b/tensorflow/contrib/training/python/training/sampling_ops_test.py index 2d663d7954..bbc0a284cd 100644 --- a/tensorflow/contrib/training/python/training/sampling_ops_test.py +++ b/tensorflow/contrib/training/python/training/sampling_ops_test.py @@ -24,7 +24,7 @@ from tensorflow.contrib.training.python.training import sampling_ops from tensorflow.python.platform import tf_logging as logging -class SamplingOpsTest(tf.test.TestCase): +class StratifiedSampleTest(tf.test.TestCase): def testGraphBuildAssertionFailures(self): val = [tf.zeros([1, 3]), tf.ones([1, 5])] @@ -383,17 +383,79 @@ class SamplingOpsTest(tf.test.TestCase): self.normalBehaviorHelper(curried_sampler) + +class RejectionSampleTest(tf.test.TestCase): + + def testGraphConstructionFailures(self): + accept_prob_fn = lambda _: tf.constant(1.0) + batch_size = 32 + # Data must have batch dimension if `enqueue_many` is `True`. + with self.assertRaises(ValueError): + tf.contrib.training.rejection_sample( + [tf.zeros([])], accept_prob_fn, batch_size, enqueue_many=True) + + # Batch dimensions should be equal if `enqueue_many` is `True`. + with self.assertRaises(ValueError): + tf.contrib.training.rejection_sample( + [tf.zeros([5, 1]), tf.zeros([4, 1])], accept_prob_fn, batch_size, + enqueue_many=True) + + def testRuntimeFailures(self): + prob_ph = tf.placeholder(tf.float32, []) + accept_prob_fn = lambda _: prob_ph + batch_size = 32 + + # Set up graph. + tf.set_random_seed(1234) + tf.contrib.training.rejection_sample( + [tf.zeros([])], accept_prob_fn, batch_size, runtime_checks=True, + name='rejection_sample') + prob_tensor = tf.get_default_graph().get_tensor_by_name( + 'rejection_sample/prob_with_checks:0') + + # Run session that should fail. + with self.test_session() as sess: + for illegal_prob in [-0.1, 1.1]: + with self.assertRaises(tf.errors.InvalidArgumentError): + sess.run(prob_tensor, feed_dict={prob_ph: illegal_prob}) + + def testNormalBehavior(self): + tensor_list = [tf.cond( + tf.greater(.5, tf.random_uniform([])), + lambda: tf.constant(1.0), + lambda: tf.constant(2.0))] + accept_prob_fn = lambda x: x[0] - 1.0 + batch_size = 10 + + # Set up graph. + sample = tf.contrib.training.rejection_sample( + tensor_list, accept_prob_fn, batch_size) + + with self.test_session() as sess: + coord = tf.train.Coordinator() + threads = tf.train.start_queue_runners(coord=coord) + + for _ in range(5): + sample_np = sess.run(sample)[0] + self.assertListEqual([2.0] * batch_size, list(sample_np)) + + coord.request_stop() + coord.join(threads) + + +class ConditionalBatchTest(tf.test.TestCase): + def testConditionallyEnqueueAndBatch(self): tf.set_random_seed(1234) tensor = tf.cond( tf.greater(.5, tf.random_uniform([])), lambda: tf.constant(1.0), lambda: tf.constant(2.0)) - accept_prob = tensor - 1 + keep_input = tf.equal(tensor, 2.0) batch_size = 4 # Set up the test graph. - [batch] = sampling_ops._conditional_batch([tensor], accept_prob, batch_size) # pylint: disable=protected-access + [batch] = sampling_ops._conditional_batch([tensor], keep_input, batch_size) # pylint: disable=protected-access # Check conditional operation. with self.test_session(): @@ -411,13 +473,13 @@ class SamplingOpsTest(tf.test.TestCase): def testConditionallyEnqueueAndBatchTypes(self): tensor = tf.constant(1.0) - accept_prob = tensor - 1 + keep_input = tf.constant(True) batch_size = 4 # Check that output types are the same for 1 and 2-length input lists. - output1 = sampling_ops._conditional_batch([tensor], accept_prob, batch_size) # pylint: disable=protected-access + output1 = sampling_ops._conditional_batch([tensor], keep_input, batch_size) # pylint: disable=protected-access output2 = sampling_ops._conditional_batch( # pylint: disable=protected-access - [tensor, tensor], accept_prob, batch_size) + [tensor, tensor], keep_input, batch_size) self.assertEqual(type(output1), type(output2)) |