diff options
author | 2017-04-05 13:11:42 -0800 | |
---|---|---|
committer | 2017-04-05 14:25:02 -0700 | |
commit | fd2bfed807979442e80f27a46df69ca7b579d42a (patch) | |
tree | 6a59d6a985f57c8231d844429dbdbfd673cb0523 | |
parent | e018a983a5abb5d849f1ec5393fb0834c7d78c8f (diff) |
Assert rank is at least equal to new_rank for `_sparse_inner_flatten`.
Change: 152303319
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers_test.py | 24 |
2 files changed, 28 insertions, 0 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 0140f6d0d3..bf2a372f04 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -1241,6 +1241,10 @@ def flatten(inputs, def _sparse_inner_flatten(inputs, new_rank): """Helper function for `inner_flatten`.""" + inputs_rank = inputs.dense_shape.get_shape().as_list()[0] + if inputs_rank < new_rank: + raise ValueError('inputs has rank less than new_rank.') + outer_dimensions = inputs.dense_shape[:new_rank - 1] inner_dimensions = inputs.dense_shape[new_rank - 1:] new_shape = array_ops.concat((outer_dimensions, diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 2b170e92ba..4cdc8ca005 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1465,6 +1465,30 @@ class PartialFlattenTest(test.TestCase): flattened5 = _layers._inner_flatten(inputs, 5) self.assertEqual([2, None, 4, None, 30], flattened5.get_shape().as_list()) + def testDenseFlattenRankAssertion(self): + """Test `_inner_flatten` rank assertion for dense tensors.""" + shape = [2, 3] + new_rank = 3 + inputs = array_ops.placeholder(dtypes.int32) + inputs.set_shape(shape) + + with self.assertRaisesRegexp(ValueError, + 'inputs has rank less than new_rank'): + _layers._inner_flatten(inputs, new_rank) + + def testSparseFlattenRankAssertion(self): + """Test `_inner_flatten` rank assertion for sparse tensors.""" + shape = [2, 3] + new_rank = 3 + np.random.seed(10301) + random_ = np.random.rand(*shape) + indices, values, _ = _sparsify(random_) + inputs = sparse_tensor.SparseTensor(indices, values, shape) + + with self.assertRaisesRegexp(ValueError, + 'inputs has rank less than new_rank'): + _layers._inner_flatten(inputs, new_rank) + class FCTest(test.TestCase): |