aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-08 12:24:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-08 12:44:33 -0800
commit7b94349a7ea79ab63ad9ab931bd2c52c8e645911 (patch)
treed7821688efcc470d06dbc685ef6095b9cd02d2dd
parent22aaf4e3539ea489d0609055b9e97e6dd772cefe (diff)
Fix bug in handling of explicitly specified num parameter of split_v
and also add associated test case. Change: 141469479
-rw-r--r--tensorflow/python/kernel_tests/split_op_test.py18
-rw-r--r--tensorflow/python/ops/array_ops.py12
2 files changed, 24 insertions, 6 deletions
diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py
index 94add940d1..60a1f4bcf3 100644
--- a/tensorflow/python/kernel_tests/split_op_test.py
+++ b/tensorflow/python/kernel_tests/split_op_test.py
@@ -24,6 +24,24 @@ import tensorflow as tf
class SplitVOpTest(tf.test.TestCase):
+ def testExplicitNum(self):
+ size_splits = tf.placeholder(dtype=tf.int32, shape=[None])
+
+ value = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+
+ with self.test_session(use_gpu=False) as sess:
+ with self.assertRaises(ValueError) as context:
+ sess.run(tf.split_v(value, size_splits), {size_splits: [2, 2, 6]})
+
+ self.assertTrue("Cannot infer num from shape" in str(context.exception))
+
+ result = sess.run(tf.split_v(value, size_splits, num=3),
+ {size_splits: [2, 2, 6]})
+
+ self.assertAllEqual(result[0], value[0:2])
+ self.assertAllEqual(result[1], value[2:4])
+ self.assertAllEqual(result[2], value[4:])
+
def testListOfScalarTensors(self):
a = tf.to_int32(5)
b = tf.to_int32(6)
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index e25a8449cb..efbbf2ed93 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1294,7 +1294,7 @@ def split_v(value=None,
```python
# 'value' is a tensor with shape [5, 30]
# Split 'value' into 3 tensors with sizes [4, 15, 11] along dimension 1
- split0, split1, split2 = tf.split_v(1, [4, 15, 11], value)
+ split0, split1, split2 = tf.split_v(value, [4, 15, 11], 1)
tf.shape(split0) ==> [5, 4]
tf.shape(split1) ==> [5, 15]
tf.shape(split2) ==> [5, 11]
@@ -1329,17 +1329,17 @@ def split_v(value=None,
return gen_array_ops._split(
split_dim=axis, num_split=num_or_size_splits, value=value, name=name)
else:
+ size_splits = ops.convert_to_tensor(num_or_size_splits)
if num is None:
- size_splits = ops.convert_to_tensor(num_or_size_splits)
size_splits_shape = size_splits.get_shape()
- num = size_splits_shape.dims
- if num is None:
- raise ValueError("Cannot infer num from shape %s" % value_shape)
+ num = size_splits_shape.dims[0]
+ if num._value is None:
+ raise ValueError("Cannot infer num from shape %s" % num_or_size_splits)
return gen_array_ops._split_v(
value=value,
size_splits=size_splits,
split_dim=axis,
- num_split=num[0],
+ num_split=num,
name=name)