aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-05 13:11:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-05 14:25:02 -0700
commitfd2bfed807979442e80f27a46df69ca7b579d42a (patch)
tree6a59d6a985f57c8231d844429dbdbfd673cb0523
parente018a983a5abb5d849f1ec5393fb0834c7d78c8f (diff)
Assert rank is at least equal to new_rank for `_sparse_inner_flatten`.
Change: 152303319
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py4
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py24
2 files changed, 28 insertions, 0 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 0140f6d0d3..bf2a372f04 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -1241,6 +1241,10 @@ def flatten(inputs,
def _sparse_inner_flatten(inputs, new_rank):
"""Helper function for `inner_flatten`."""
+ inputs_rank = inputs.dense_shape.get_shape().as_list()[0]
+ if inputs_rank < new_rank:
+ raise ValueError('inputs has rank less than new_rank.')
+
outer_dimensions = inputs.dense_shape[:new_rank - 1]
inner_dimensions = inputs.dense_shape[new_rank - 1:]
new_shape = array_ops.concat((outer_dimensions,
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 2b170e92ba..4cdc8ca005 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1465,6 +1465,30 @@ class PartialFlattenTest(test.TestCase):
flattened5 = _layers._inner_flatten(inputs, 5)
self.assertEqual([2, None, 4, None, 30], flattened5.get_shape().as_list())
+ def testDenseFlattenRankAssertion(self):
+ """Test `_inner_flatten` rank assertion for dense tensors."""
+ shape = [2, 3]
+ new_rank = 3
+ inputs = array_ops.placeholder(dtypes.int32)
+ inputs.set_shape(shape)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'inputs has rank less than new_rank'):
+ _layers._inner_flatten(inputs, new_rank)
+
+ def testSparseFlattenRankAssertion(self):
+ """Test `_inner_flatten` rank assertion for sparse tensors."""
+ shape = [2, 3]
+ new_rank = 3
+ np.random.seed(10301)
+ random_ = np.random.rand(*shape)
+ indices, values, _ = _sparsify(random_)
+ inputs = sparse_tensor.SparseTensor(indices, values, shape)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'inputs has rank less than new_rank'):
+ _layers._inner_flatten(inputs, new_rank)
+
class FCTest(test.TestCase):