diff options
author | 2017-01-04 07:18:16 -0800 | |
---|---|---|
committer | 2017-01-04 07:25:17 -0800 | |
commit | 012800e2368de26f677f4cf0093ef8c7b51c3070 (patch) | |
tree | d2c15bd4b7dd38a8aa478c2a2dd98119a80e9ab8 | |
parent | ea296160bc5b7b2c9920a7487650f12d47029338 (diff) |
Change for internal compatibility.
-rw-r--r-- | tensorflow/contrib/training/python/training/sampling_ops.py | 81 | ||||
-rw-r--r-- | tensorflow/contrib/training/python/training/sampling_ops_test.py | 39 |
2 files changed, 17 insertions, 103 deletions
diff --git a/tensorflow/contrib/training/python/training/sampling_ops.py b/tensorflow/contrib/training/python/training/sampling_ops.py index 97279dc457..bf1d2c8cad 100644 --- a/tensorflow/contrib/training/python/training/sampling_ops.py +++ b/tensorflow/contrib/training/python/training/sampling_ops.py @@ -22,15 +22,12 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.summary import summary from tensorflow.python.training import input as input_ops -from tensorflow.python.training import queue_runner __all__ = [ 'rejection_sample', @@ -121,9 +118,18 @@ def rejection_sample(tensors, check_ops.assert_less_equal(cur_prob, 1.0) ], cur_prob), name='prob_with_checks') - keep_input = random_ops.random_uniform([]) < cur_prob - return _conditional_batch( - tensor_list, keep_input, batch_size, num_threads=queue_threads) + minibatch = input_ops.maybe_batch( + tensor_list, + keep_input=random_ops.random_uniform([]) < cur_prob, + batch_size=batch_size, + num_threads=queue_threads) + + # Queues return a single tensor if the list of enqued tensors is one. Since + # we want the type to always be the same, always return a list. + if isinstance(minibatch, ops.Tensor): + minibatch = [minibatch] + + return minibatch def stratified_sample(tensors, @@ -213,9 +219,8 @@ def stratified_sample(tensors, math_ops.logical_or( math_ops.not_equal(init_probs, 0), math_ops.equal(target_probs, 0))), - [ - 'All classes with zero initial probability must also have zero target ' - 'probability: ', init_probs, target_probs + ['All classes with zero initial probability must also have zero target ' + 'probability: ', init_probs, target_probs ]) init_probs = control_flow_ops.with_dependencies([assert_op], init_probs) @@ -244,11 +249,10 @@ def stratified_sample(tensors, # Set up second queue containing batches that have the desired class # proportions. cur_prob = array_ops.gather(accept_probs, label) - keep_input = random_ops.random_uniform([]) < cur_prob - batched = _conditional_batch( + batched = input_ops.maybe_batch( val_list + [label], - keep_input, - batch_size, + keep_input=random_ops.random_uniform([]) < cur_prob, + batch_size=batch_size, num_threads=threads_per_queue) return batched[:-1], batched[-1] @@ -416,54 +420,3 @@ def _calculate_acceptance_probabilities(init_probs, target_probs): # Calculate list of acceptance probabilities. max_ratio = math_ops.reduce_max(ratio_l) return ratio_l / max_ratio - - -def _conditional_batch(tensors, keep_input, batch_size, num_threads=10): - """Conditionally enqueue tensors based on accept_prob. - - Specifically, enqueue the element if accept_prob > rand_unif([0, 1]). - - Args: - tensors: List of tensors to enqueue. - keep_input: Bool. Whether to enqueue or not. - batch_size: Size of batch. - num_threads: Number of enqueueing threads. - - Returns: - List of batched tensors. - - Raises: - ValueError: `accept_prob` isn't 0D. - """ - keep_input.get_shape().assert_has_rank(0) - # Determine shapes and types of to-be-enqueued-tensors. - shapes_list = [] - dtypes_list = [] - for tensor in tensors: - cur_shape = tensor.get_shape() - cur_shape.assert_is_fully_defined() - shapes_list.append(cur_shape) - dtypes_list.append(tensor.dtype) - - final_q = data_flow_ops.FIFOQueue( - capacity=batch_size, - shapes=shapes_list, - dtypes=dtypes_list, - name='batched_queue') - summary.scalar('queue/%s/size' % final_q.name, final_q.size()) - - # Conditionally enqueue. - # Reshape enqueue op to match no_op's shape. - conditional_enqueue = control_flow_ops.cond(keep_input, - lambda: final_q.enqueue(tensors), - control_flow_ops.no_op) - queue_runner.add_queue_runner( - queue_runner.QueueRunner(final_q, [conditional_enqueue] * num_threads)) - - out_tensor = final_q.dequeue_many(batch_size) - # Queues return a single tensor if the list of enqued tensors is one. Since we - # want the type to be the same in all cases, always return a list. - if isinstance(out_tensor, ops.Tensor): - out_tensor = [out_tensor] - - return out_tensor diff --git a/tensorflow/contrib/training/python/training/sampling_ops_test.py b/tensorflow/contrib/training/python/training/sampling_ops_test.py index 1a34ee0953..bf7fb4fd48 100644 --- a/tensorflow/contrib/training/python/training/sampling_ops_test.py +++ b/tensorflow/contrib/training/python/training/sampling_ops_test.py @@ -405,44 +405,5 @@ class RejectionSampleTest(test.TestCase): coord.join(threads) -class ConditionalBatchTest(test.TestCase): - - def testConditionallyEnqueueAndBatch(self): - random_seed.set_random_seed(1234) - tensor = control_flow_ops.cond( - math_ops.greater(.5, random_ops.random_uniform([])), - lambda: constant_op.constant(1.0), lambda: constant_op.constant(2.0)) - keep_input = math_ops.equal(tensor, 2.0) - batch_size = 4 - - # Set up the test graph. - [batch] = sampling_ops._conditional_batch([tensor], keep_input, batch_size) # pylint: disable=protected-access - - # Check conditional operation. - with self.test_session(): - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(coord=coord) - - batch_np = batch.eval() - - coord.request_stop() - coord.join(threads) - - # Check that all elements in batch come from tensors with acceptance prob - # 1, so that none come from acceptance prob 0. - self.assertListEqual(list(batch_np), [2.0] * batch_size) - - def testConditionallyEnqueueAndBatchTypes(self): - tensor = constant_op.constant(1.0) - keep_input = constant_op.constant(True) - batch_size = 4 - - # Check that output types are the same for 1 and 2-length input lists. - output1 = sampling_ops._conditional_batch([tensor], keep_input, batch_size) # pylint: disable=protected-access - output2 = sampling_ops._conditional_batch( # pylint: disable=protected-access - [tensor, tensor], keep_input, batch_size) - self.assertEqual(type(output1), type(output2)) - - if __name__ == '__main__': test.main() |