author     A. Unique TensorFlower <gardener@tensorflow.org>  2016-10-11 17:34:13 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2016-10-11 18:50:14 -0700
commit     2c183364e6994ef00d3cff930cb383c4e9443f25 (patch)
tree       95abd17465954d768d834f2afeea666f3f208f9f /tensorflow
parent     40fc65b0f5791f40132fbd173bfde7521f14fed5 (diff)
Check in code to perform rejection sampling.
Change: 135869416
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/training/__init__.py                          |  11
-rw-r--r--  tensorflow/contrib/training/python/training/sampling_ops.py      | 115
-rw-r--r--  tensorflow/contrib/training/python/training/sampling_ops_test.py |  74
3 files changed, 181 insertions(+), 19 deletions(-)
diff --git a/tensorflow/contrib/training/__init__.py b/tensorflow/contrib/training/__init__.py
index fc0e324bcf..1290854260 100644
--- a/tensorflow/contrib/training/__init__.py
+++ b/tensorflow/contrib/training/__init__.py
@@ -30,10 +30,12 @@ like to store state in the forward direction across segments of an example.
## Online data resampling
To resample data with replacement on a per-example basis, use
-['resample_at_rate'](#resample_at_rate), providing the desired rate
-for each example. If you wish to specify relative rates, rather than
-absolute ones, use ['weighted_resample'](#weighted_resample) (which
-also returns the actual resampling rate used for each output example).
+['rejection_sample'](#rejection_sample) or
+['resample_at_rate'](#resample_at_rate). For `rejection_sample`, provide a
+function that computes a scalar acceptance probability for each example. For
+`resample_at_rate`, provide the desired rate for each example. If you wish to
+specify relative rates, rather than absolute ones, use ['weighted_resample'](#weighted_resample)
+(which also returns the actual resampling rate used for each output example).
Use ['stratified_sample'](#stratified_sample) or
['stratified_sample_unknown_dist'](#stratified_sample_unknown_dist) to
@@ -43,6 +45,7 @@ have a binary classification dataset that is 99.9% class 1, a common
approach is to resample from the data so that the data is more
balanced.
+@@rejection_sample
@@resample_at_rate
@@stratified_sample
@@stratified_sample_unknown_dist
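For concreteness, here is a minimal sketch of driving the new `rejection_sample` entry point end to end. The scalar `data` and `label` tensors are stand-ins; only `rejection_sample` itself and the standard queue-runner calls come from TensorFlow:

import tensorflow as tf

# Stand-in inputs: a scalar data tensor and a scalar label tensor.
data = tf.random_uniform([])
label = tf.constant(0)

# Accept an example with probability (tanh(data) + 1) / 2, as in the
# docstring example further down in this change.
accept_prob_fn = lambda x: (tf.tanh(x[0]) + 1) / 2

data_batch, label_batch = tf.contrib.training.rejection_sample(
    [data, label], accept_prob_fn, batch_size=16)

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)
  print(sess.run([data_batch, label_batch]))
  coord.request_stop()
  coord.join(threads)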
diff --git a/tensorflow/contrib/training/python/training/sampling_ops.py b/tensorflow/contrib/training/python/training/sampling_ops.py
index c703e22e24..05f5ec6b39 100644
--- a/tensorflow/contrib/training/python/training/sampling_ops.py
+++ b/tensorflow/contrib/training/python/training/sampling_ops.py
@@ -27,14 +27,96 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import input as input_ops
from tensorflow.python.training import queue_runner
-__all__ = ['stratified_sample',
+
+__all__ = ['rejection_sample',
+ 'stratified_sample',
'stratified_sample_unknown_dist',]
+def rejection_sample(tensors, accept_prob_fn, batch_size, queue_threads=1,
+ enqueue_many=False, prebatch_capacity=16,
+ prebatch_threads=1, runtime_checks=False, name=None):
+ """Stochastically creates batches by rejection sampling.
+
+  Each group of non-batched tensors is passed to `accept_prob_fn`, which
+  produces a scalar tensor between 0 and 1 giving the probability that the
+  group is accepted. When `batch_size` tensor groups have been accepted, the
+  batch queue will return a mini-batch.
+
+ Args:
+ tensors: List of tensors for data. All tensors are either one item or a
+ batch, according to enqueue_many.
+    accept_prob_fn: A python lambda that takes a non-batch tensor from each
+      item in `tensors`, and produces a scalar tensor in [0, 1] giving the
+      acceptance probability.
+ batch_size: Size of batch to be returned.
+ queue_threads: The number of threads for the queue that will hold the final
+ batch.
+ enqueue_many: Bool. If true, interpret input tensors as having a batch
+ dimension.
+ prebatch_capacity: Capacity for the large queue that is used to convert
+ batched tensors to single examples.
+ prebatch_threads: Number of threads for the large queue that is used to
+ convert batched tensors to single examples.
+ runtime_checks: Bool. If true, insert runtime checks on the output of
+ `accept_prob_fn`. Using `True` might have a performance impact.
+ name: Optional prefix for ops created by this function.
+  Raises:
+    ValueError: enqueue_many is True and the input tensors don't have a batch
+      dimension, or the batch dimensions of the input tensors don't match.
+ Returns:
+ A list of tensors of the same length as `tensors`, with batch dimension
+ `batch_size`.
+
+ Example:
+ # Get tensor for a single data and label example.
+ data, label = data_provider.Get(['data', 'label'])
+
+  # Get a rejection-sampled batch according to the data tensor.
+ accept_prob_fn = lambda x: (tf.tanh(x[0]) + 1) / 2
+ data_batch = tf.contrib.training.rejection_sample(
+ [data, label], accept_prob_fn, 16)
+
+ # Run batch through network.
+ ...
+ """
+ with variable_scope.variable_scope(name, 'rejection_sample', tensors):
+ tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
+    # Reduce the batched-input case to the single-example case by dequeuing
+    # batches of size one.
+ if enqueue_many:
+ # Validate that batch dimension of the input is consistent.
+ tensor_list = _verify_data_inputs(tensor_list)
+
+    # Make a single queue to hold input examples. Reshape the output so
+    # examples don't have a singleton batch dimension.
+ batched = input_ops.batch(tensor_list,
+ batch_size=1,
+ num_threads=prebatch_threads,
+ capacity=prebatch_capacity,
+ enqueue_many=True)
+ tensor_list = [array_ops.squeeze(x, [0]) for x in batched]
+
+    # Set up a queue containing the examples accepted by rejection sampling.
+ cur_prob = accept_prob_fn(tensor_list)
+ if runtime_checks:
+ cur_prob = array_ops.identity(control_flow_ops.with_dependencies(
+ [check_ops.assert_less_equal(0.0, cur_prob),
+ check_ops.assert_less_equal(cur_prob, 1.0)],
+ cur_prob), name='prob_with_checks')
+ keep_input = random_ops.random_uniform([]) < cur_prob
+ return _conditional_batch(
+ tensor_list, keep_input, batch_size, num_threads=queue_threads)
+
+
def stratified_sample(tensors, labels, target_probs, batch_size,
init_probs=None, enqueue_many=False, queue_capacity=16,
threads_per_queue=1, name=None):
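The control flow above amounts to: draw an example, flip a coin that lands heads with the example's acceptance probability, and keep the example only on heads, repeating until `batch_size` examples have been kept. A pure-Python sketch of that semantics (illustrative only; the real implementation is queue-based and multi-threaded):

import random

def rejection_sample_py(next_example, accept_prob_fn, batch_size):
  """Keeps drawing examples until batch_size pass an independent coin flip,
  mirroring the `random_uniform([]) < cur_prob` test above."""
  batch = []
  while len(batch) < batch_size:
    example = next_example()
    if random.random() < accept_prob_fn(example):
      batch.append(example)
  return batch

# Examples equal to 2.0 are accepted with probability 1.0 and examples equal
# to 1.0 with probability 0.0, as in the unit test in this change.
batch = rejection_sample_py(
    next_example=lambda: random.choice([1.0, 2.0]),
    accept_prob_fn=lambda x: x - 1.0,
    batch_size=10)
assert batch == [2.0] * 10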
@@ -145,8 +227,12 @@ def stratified_sample(tensors, labels, target_probs, batch_size,
# Set up second queue containing batches that have the desired class
# proportions.
cur_prob = array_ops.gather(accept_probs, label)
+ keep_input = random_ops.random_uniform([]) < cur_prob
batched = _conditional_batch(
- val_list + [label], cur_prob, batch_size, threads_per_queue)
+ val_list + [label],
+ keep_input,
+ batch_size,
+ num_threads=threads_per_queue)
return batched[:-1], batched[-1]
@@ -260,6 +346,18 @@ def _estimate_data_distribution(labels, num_classes, smoothing_constant=10):
return math_ops.cast(init_prob_estimate, dtypes.float32)
+def _verify_data_inputs(tensor_list):
+ """Verify that batched data inputs are well-formed."""
+ for tensor in tensor_list:
+ # Data tensor should have a batch dimension.
+ tensor_shape = tensor.get_shape().with_rank_at_least(1)
+
+ # Data batch dimensions must be compatible.
+ tensor_shape[0].assert_is_compatible_with(tensor_list[0].get_shape()[0])
+
+ return tensor_list
+
+
def _verify_input(tensor_list, labels, probs_list):
"""Verify that batched inputs are well-formed."""
checked_probs_list = []
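`_verify_data_inputs` relies entirely on static shape information: `with_rank_at_least(1)` rejects tensors with no batch dimension, and `assert_is_compatible_with` rejects mismatched batch dimensions (unknown dimensions are compatible with anything). A small sketch of the two checks in isolation, using `TensorShape` directly:

from tensorflow.python.framework import tensor_shape

batch_a = tensor_shape.TensorShape([5, 1])
batch_b = tensor_shape.TensorShape([5, 3])

batch_a.with_rank_at_least(1)                     # passes: has a batch dimension
batch_a[0].assert_is_compatible_with(batch_b[0])  # passes: both batch dims are 5

# Mismatched batch dimensions fail, matching the test below that pairs
# tf.zeros([5, 1]) with tf.zeros([4, 1]):
try:
  batch_a[0].assert_is_compatible_with(tensor_shape.Dimension(4))
except ValueError as e:
  print(e)  # Dimensions 5 and 4 are not compatible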
@@ -374,16 +472,16 @@ def _calculate_acceptance_probabilities(init_probs, target_probs):
return ratio_l / max_ratio
-def _conditional_batch(tensors, accept_prob, batch_size, queue_threads=10):
+def _conditional_batch(tensors, keep_input, batch_size, num_threads=10):
"""Conditionally enqueue tensors based on accept_prob.
Specifically, enqueue the element if accept_prob > rand_unif([0, 1]).
Args:
tensors: List of tensors to enqueue.
- accept_prob: Acceptance probability per example.
+ keep_input: Bool. Whether to enqueue or not.
batch_size: Size of batch.
- queue_threads: Number of threads enqueuing in the final queue.
+ num_threads: Number of enqueueing threads.
Returns:
List of batched tensors.
@@ -391,7 +489,7 @@ def _conditional_batch(tensors, accept_prob, batch_size, queue_threads=10):
Raises:
-    ValueError: `accept_prob` isn't 0D.
+    ValueError: `keep_input` isn't 0D.
"""
- accept_prob.get_shape().assert_has_rank(0)
+ keep_input.get_shape().assert_has_rank(0)
# Determine shapes and types of to-be-enqueued-tensors.
shapes_list = []
dtypes_list = []
@@ -409,13 +507,12 @@ def _conditional_batch(tensors, accept_prob, batch_size, queue_threads=10):
# Conditionally enqueue.
# Reshape enqueue op to match no_op's shape.
- eq_tf = math_ops.less(random_ops.random_uniform([]), accept_prob)
conditional_enqueue = control_flow_ops.cond(
- eq_tf,
+ keep_input,
lambda: final_q.enqueue(tensors),
control_flow_ops.no_op)
queue_runner.add_queue_runner(queue_runner.QueueRunner(
- final_q, [conditional_enqueue] * queue_threads))
+ final_q, [conditional_enqueue] * num_threads))
out_tensor = final_q.dequeue_many(batch_size)
  # Queues return a single tensor if the list of enqueued tensors has length one. Since we
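`_conditional_batch` gates its enqueue op with `cond`: when the scalar boolean is `True` the example is enqueued, otherwise a `no_op` runs and the queue runner simply tries again. A self-contained sketch of that gating pattern using the public API of this era (the queue, example tensor, and threshold are stand-ins):

import tensorflow as tf

final_q = tf.FIFOQueue(capacity=32, dtypes=[tf.float32], shapes=[[]])
example = tf.random_uniform([])
keep_input = tf.greater(example, 0.5)  # must be a scalar (rank-0) boolean

# Enqueue the example only when keep_input is True; otherwise do nothing.
conditional_enqueue = tf.cond(
    keep_input,
    lambda: final_q.enqueue([example]),
    tf.no_op)

tf.train.add_queue_runner(tf.train.QueueRunner(
    final_q, [conditional_enqueue] * 4))
batch = final_q.dequeue_many(8)  # blocks until 8 examples have been accepted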
diff --git a/tensorflow/contrib/training/python/training/sampling_ops_test.py b/tensorflow/contrib/training/python/training/sampling_ops_test.py
index 2d663d7954..bbc0a284cd 100644
--- a/tensorflow/contrib/training/python/training/sampling_ops_test.py
+++ b/tensorflow/contrib/training/python/training/sampling_ops_test.py
@@ -24,7 +24,7 @@ from tensorflow.contrib.training.python.training import sampling_ops
from tensorflow.python.platform import tf_logging as logging
-class SamplingOpsTest(tf.test.TestCase):
+class StratifiedSampleTest(tf.test.TestCase):
def testGraphBuildAssertionFailures(self):
val = [tf.zeros([1, 3]), tf.ones([1, 5])]
@@ -383,17 +383,79 @@ class SamplingOpsTest(tf.test.TestCase):
self.normalBehaviorHelper(curried_sampler)
+
+class RejectionSampleTest(tf.test.TestCase):
+
+ def testGraphConstructionFailures(self):
+ accept_prob_fn = lambda _: tf.constant(1.0)
+ batch_size = 32
+ # Data must have batch dimension if `enqueue_many` is `True`.
+ with self.assertRaises(ValueError):
+ tf.contrib.training.rejection_sample(
+ [tf.zeros([])], accept_prob_fn, batch_size, enqueue_many=True)
+
+ # Batch dimensions should be equal if `enqueue_many` is `True`.
+ with self.assertRaises(ValueError):
+ tf.contrib.training.rejection_sample(
+ [tf.zeros([5, 1]), tf.zeros([4, 1])], accept_prob_fn, batch_size,
+ enqueue_many=True)
+
+ def testRuntimeFailures(self):
+ prob_ph = tf.placeholder(tf.float32, [])
+ accept_prob_fn = lambda _: prob_ph
+ batch_size = 32
+
+ # Set up graph.
+ tf.set_random_seed(1234)
+ tf.contrib.training.rejection_sample(
+ [tf.zeros([])], accept_prob_fn, batch_size, runtime_checks=True,
+ name='rejection_sample')
+ prob_tensor = tf.get_default_graph().get_tensor_by_name(
+ 'rejection_sample/prob_with_checks:0')
+
+ # Run session that should fail.
+ with self.test_session() as sess:
+ for illegal_prob in [-0.1, 1.1]:
+ with self.assertRaises(tf.errors.InvalidArgumentError):
+ sess.run(prob_tensor, feed_dict={prob_ph: illegal_prob})
+
+ def testNormalBehavior(self):
+ tensor_list = [tf.cond(
+ tf.greater(.5, tf.random_uniform([])),
+ lambda: tf.constant(1.0),
+ lambda: tf.constant(2.0))]
+  # Maps the value 1.0 to acceptance probability 0.0 and the value 2.0 to
+  # probability 1.0, so only examples equal to 2.0 are ever accepted.
+  accept_prob_fn = lambda x: x[0] - 1.0
+ batch_size = 10
+
+ # Set up graph.
+ sample = tf.contrib.training.rejection_sample(
+ tensor_list, accept_prob_fn, batch_size)
+
+ with self.test_session() as sess:
+ coord = tf.train.Coordinator()
+ threads = tf.train.start_queue_runners(coord=coord)
+
+ for _ in range(5):
+ sample_np = sess.run(sample)[0]
+ self.assertListEqual([2.0] * batch_size, list(sample_np))
+
+ coord.request_stop()
+ coord.join(threads)
+
+
+class ConditionalBatchTest(tf.test.TestCase):
+
def testConditionallyEnqueueAndBatch(self):
tf.set_random_seed(1234)
tensor = tf.cond(
tf.greater(.5, tf.random_uniform([])),
lambda: tf.constant(1.0),
lambda: tf.constant(2.0))
- accept_prob = tensor - 1
+ keep_input = tf.equal(tensor, 2.0)
batch_size = 4
# Set up the test graph.
- [batch] = sampling_ops._conditional_batch([tensor], accept_prob, batch_size) # pylint: disable=protected-access
+ [batch] = sampling_ops._conditional_batch([tensor], keep_input, batch_size) # pylint: disable=protected-access
# Check conditional operation.
with self.test_session():
@@ -411,13 +473,13 @@ class SamplingOpsTest(tf.test.TestCase):
def testConditionallyEnqueueAndBatchTypes(self):
tensor = tf.constant(1.0)
- accept_prob = tensor - 1
+ keep_input = tf.constant(True)
batch_size = 4
# Check that output types are the same for 1 and 2-length input lists.
- output1 = sampling_ops._conditional_batch([tensor], accept_prob, batch_size) # pylint: disable=protected-access
+ output1 = sampling_ops._conditional_batch([tensor], keep_input, batch_size) # pylint: disable=protected-access
output2 = sampling_ops._conditional_batch( # pylint: disable=protected-access
- [tensor, tensor], accept_prob, batch_size)
+ [tensor, tensor], keep_input, batch_size)
self.assertEqual(type(output1), type(output2))