diff options
author | 2016-12-08 12:24:22 -0800 | |
---|---|---|
committer | 2016-12-08 12:44:33 -0800 | |
commit | 7b94349a7ea79ab63ad9ab931bd2c52c8e645911 (patch) | |
tree | d7821688efcc470d06dbc685ef6095b9cd02d2dd | |
parent | 22aaf4e3539ea489d0609055b9e97e6dd772cefe (diff) |
Fix bug in handling of explicitly specified num parameter of split_v
and also add associated test case.
Change: 141469479
-rw-r--r-- | tensorflow/python/kernel_tests/split_op_test.py | 18 | ||||
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 12 |
2 files changed, 24 insertions, 6 deletions
diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py index 94add940d1..60a1f4bcf3 100644 --- a/tensorflow/python/kernel_tests/split_op_test.py +++ b/tensorflow/python/kernel_tests/split_op_test.py @@ -24,6 +24,24 @@ import tensorflow as tf class SplitVOpTest(tf.test.TestCase): + def testExplicitNum(self): + size_splits = tf.placeholder(dtype=tf.int32, shape=[None]) + + value = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + with self.test_session(use_gpu=False) as sess: + with self.assertRaises(ValueError) as context: + sess.run(tf.split_v(value, size_splits), {size_splits: [2, 2, 6]}) + + self.assertTrue("Cannot infer num from shape" in str(context.exception)) + + result = sess.run(tf.split_v(value, size_splits, num=3), + {size_splits: [2, 2, 6]}) + + self.assertAllEqual(result[0], value[0:2]) + self.assertAllEqual(result[1], value[2:4]) + self.assertAllEqual(result[2], value[4:]) + def testListOfScalarTensors(self): a = tf.to_int32(5) b = tf.to_int32(6) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index e25a8449cb..efbbf2ed93 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1294,7 +1294,7 @@ def split_v(value=None, ```python # 'value' is a tensor with shape [5, 30] # Split 'value' into 3 tensors with sizes [4, 15, 11] along dimension 1 - split0, split1, split2 = tf.split_v(1, [4, 15, 11], value) + split0, split1, split2 = tf.split_v(value, [4, 15, 11], 1) tf.shape(split0) ==> [5, 4] tf.shape(split1) ==> [5, 15] tf.shape(split2) ==> [5, 11] @@ -1329,17 +1329,17 @@ def split_v(value=None, return gen_array_ops._split( split_dim=axis, num_split=num_or_size_splits, value=value, name=name) else: + size_splits = ops.convert_to_tensor(num_or_size_splits) if num is None: - size_splits = ops.convert_to_tensor(num_or_size_splits) size_splits_shape = size_splits.get_shape() - num = size_splits_shape.dims - if num is None: - raise ValueError("Cannot infer num from shape %s" % value_shape) + num = size_splits_shape.dims[0] + if num._value is None: + raise ValueError("Cannot infer num from shape %s" % num_or_size_splits) return gen_array_ops._split_v( value=value, size_splits=size_splits, split_dim=axis, - num_split=num[0], + num_split=num, name=name) |