aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-01-04 07:18:16 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-04 07:25:17 -0800
commit012800e2368de26f677f4cf0093ef8c7b51c3070 (patch)
treed2c15bd4b7dd38a8aa478c2a2dd98119a80e9ab8
parentea296160bc5b7b2c9920a7487650f12d47029338 (diff)
Change for internal compatibility.
-rw-r--r--tensorflow/contrib/training/python/training/sampling_ops.py81
-rw-r--r--tensorflow/contrib/training/python/training/sampling_ops_test.py39
2 files changed, 17 insertions, 103 deletions
diff --git a/tensorflow/contrib/training/python/training/sampling_ops.py b/tensorflow/contrib/training/python/training/sampling_ops.py
index 97279dc457..bf1d2c8cad 100644
--- a/tensorflow/contrib/training/python/training/sampling_ops.py
+++ b/tensorflow/contrib/training/python/training/sampling_ops.py
@@ -22,15 +22,12 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
-from tensorflow.python.summary import summary
from tensorflow.python.training import input as input_ops
-from tensorflow.python.training import queue_runner
__all__ = [
'rejection_sample',
@@ -121,9 +118,18 @@ def rejection_sample(tensors,
check_ops.assert_less_equal(cur_prob, 1.0)
], cur_prob),
name='prob_with_checks')
- keep_input = random_ops.random_uniform([]) < cur_prob
- return _conditional_batch(
- tensor_list, keep_input, batch_size, num_threads=queue_threads)
+ minibatch = input_ops.maybe_batch(
+ tensor_list,
+ keep_input=random_ops.random_uniform([]) < cur_prob,
+ batch_size=batch_size,
+ num_threads=queue_threads)
+
+ # Queues return a single tensor if the list of enqued tensors is one. Since
+ # we want the type to always be the same, always return a list.
+ if isinstance(minibatch, ops.Tensor):
+ minibatch = [minibatch]
+
+ return minibatch
def stratified_sample(tensors,
@@ -213,9 +219,8 @@ def stratified_sample(tensors,
math_ops.logical_or(
math_ops.not_equal(init_probs, 0),
math_ops.equal(target_probs, 0))),
- [
- 'All classes with zero initial probability must also have zero target '
- 'probability: ', init_probs, target_probs
+ ['All classes with zero initial probability must also have zero target '
+ 'probability: ', init_probs, target_probs
])
init_probs = control_flow_ops.with_dependencies([assert_op], init_probs)
@@ -244,11 +249,10 @@ def stratified_sample(tensors,
# Set up second queue containing batches that have the desired class
# proportions.
cur_prob = array_ops.gather(accept_probs, label)
- keep_input = random_ops.random_uniform([]) < cur_prob
- batched = _conditional_batch(
+ batched = input_ops.maybe_batch(
val_list + [label],
- keep_input,
- batch_size,
+ keep_input=random_ops.random_uniform([]) < cur_prob,
+ batch_size=batch_size,
num_threads=threads_per_queue)
return batched[:-1], batched[-1]
@@ -416,54 +420,3 @@ def _calculate_acceptance_probabilities(init_probs, target_probs):
# Calculate list of acceptance probabilities.
max_ratio = math_ops.reduce_max(ratio_l)
return ratio_l / max_ratio
-
-
-def _conditional_batch(tensors, keep_input, batch_size, num_threads=10):
- """Conditionally enqueue tensors based on accept_prob.
-
- Specifically, enqueue the element if accept_prob > rand_unif([0, 1]).
-
- Args:
- tensors: List of tensors to enqueue.
- keep_input: Bool. Whether to enqueue or not.
- batch_size: Size of batch.
- num_threads: Number of enqueueing threads.
-
- Returns:
- List of batched tensors.
-
- Raises:
- ValueError: `accept_prob` isn't 0D.
- """
- keep_input.get_shape().assert_has_rank(0)
- # Determine shapes and types of to-be-enqueued-tensors.
- shapes_list = []
- dtypes_list = []
- for tensor in tensors:
- cur_shape = tensor.get_shape()
- cur_shape.assert_is_fully_defined()
- shapes_list.append(cur_shape)
- dtypes_list.append(tensor.dtype)
-
- final_q = data_flow_ops.FIFOQueue(
- capacity=batch_size,
- shapes=shapes_list,
- dtypes=dtypes_list,
- name='batched_queue')
- summary.scalar('queue/%s/size' % final_q.name, final_q.size())
-
- # Conditionally enqueue.
- # Reshape enqueue op to match no_op's shape.
- conditional_enqueue = control_flow_ops.cond(keep_input,
- lambda: final_q.enqueue(tensors),
- control_flow_ops.no_op)
- queue_runner.add_queue_runner(
- queue_runner.QueueRunner(final_q, [conditional_enqueue] * num_threads))
-
- out_tensor = final_q.dequeue_many(batch_size)
- # Queues return a single tensor if the list of enqued tensors is one. Since we
- # want the type to be the same in all cases, always return a list.
- if isinstance(out_tensor, ops.Tensor):
- out_tensor = [out_tensor]
-
- return out_tensor
diff --git a/tensorflow/contrib/training/python/training/sampling_ops_test.py b/tensorflow/contrib/training/python/training/sampling_ops_test.py
index 1a34ee0953..bf7fb4fd48 100644
--- a/tensorflow/contrib/training/python/training/sampling_ops_test.py
+++ b/tensorflow/contrib/training/python/training/sampling_ops_test.py
@@ -405,44 +405,5 @@ class RejectionSampleTest(test.TestCase):
coord.join(threads)
-class ConditionalBatchTest(test.TestCase):
-
- def testConditionallyEnqueueAndBatch(self):
- random_seed.set_random_seed(1234)
- tensor = control_flow_ops.cond(
- math_ops.greater(.5, random_ops.random_uniform([])),
- lambda: constant_op.constant(1.0), lambda: constant_op.constant(2.0))
- keep_input = math_ops.equal(tensor, 2.0)
- batch_size = 4
-
- # Set up the test graph.
- [batch] = sampling_ops._conditional_batch([tensor], keep_input, batch_size) # pylint: disable=protected-access
-
- # Check conditional operation.
- with self.test_session():
- coord = coordinator.Coordinator()
- threads = queue_runner_impl.start_queue_runners(coord=coord)
-
- batch_np = batch.eval()
-
- coord.request_stop()
- coord.join(threads)
-
- # Check that all elements in batch come from tensors with acceptance prob
- # 1, so that none come from acceptance prob 0.
- self.assertListEqual(list(batch_np), [2.0] * batch_size)
-
- def testConditionallyEnqueueAndBatchTypes(self):
- tensor = constant_op.constant(1.0)
- keep_input = constant_op.constant(True)
- batch_size = 4
-
- # Check that output types are the same for 1 and 2-length input lists.
- output1 = sampling_ops._conditional_batch([tensor], keep_input, batch_size) # pylint: disable=protected-access
- output2 = sampling_ops._conditional_batch( # pylint: disable=protected-access
- [tensor, tensor], keep_input, batch_size)
- self.assertEqual(type(output1), type(output2))
-
-
if __name__ == '__main__':
test.main()