aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-18 11:49:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-18 13:04:09 -0700
commit3e4ffd1deb3922bd9f2ac63589ffe5ae96400328 (patch)
tree60c863e701aa903644b78377e400d1a304059aba /tensorflow/contrib
parent74014d49384531b9209ad38ed7555872fcdfb4bf (diff)
Minor simplification in sampling_ops.py. Effectively a noop.
Change: 127747773
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/framework/python/ops/sampling_ops.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/tensorflow/contrib/framework/python/ops/sampling_ops.py b/tensorflow/contrib/framework/python/ops/sampling_ops.py
index 5a98f5e0dd..d44fe3b3f6 100644
--- a/tensorflow/contrib/framework/python/ops/sampling_ops.py
+++ b/tensorflow/contrib/framework/python/ops/sampling_ops.py
@@ -128,10 +128,8 @@ def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
num_threads=threads_per_queue,
capacity=queue_capacity,
enqueue_many=True)
- val_list = [array_ops.reshape(x, y.get_shape().with_rank_at_least(1)[1:])
- for x, y in zip(batched[:-1], tensor_list)]
- label = array_ops.reshape(
- batched[-1], labels.get_shape().with_rank_at_least(1)[1:])
+ val_list = [array_ops.squeeze(x, [0]) for x in batched[:-1]]
+ label = array_ops.squeeze(batched[-1], [0])
# Set up second queue containing batches that have the desired class
# proportions.