aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/framework/python/ops/sampling_ops.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/tensorflow/contrib/framework/python/ops/sampling_ops.py b/tensorflow/contrib/framework/python/ops/sampling_ops.py
index 5a98f5e0dd..d44fe3b3f6 100644
--- a/tensorflow/contrib/framework/python/ops/sampling_ops.py
+++ b/tensorflow/contrib/framework/python/ops/sampling_ops.py
@@ -128,10 +128,8 @@ def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
num_threads=threads_per_queue,
capacity=queue_capacity,
enqueue_many=True)
- val_list = [array_ops.reshape(x, y.get_shape().with_rank_at_least(1)[1:])
- for x, y in zip(batched[:-1], tensor_list)]
- label = array_ops.reshape(
- batched[-1], labels.get_shape().with_rank_at_least(1)[1:])
+ val_list = [array_ops.squeeze(x, [0]) for x in batched[:-1]]
+ label = array_ops.squeeze(batched[-1], [0])
# Set up second queue containing batches that have the desired class
# proportions.