diff options
author | David G. Andersen <dga@google.com> | 2016-04-08 15:51:41 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-04-08 17:02:46 -0700 |
commit | 4d9ec5ece5771a1982352574ce2cad587644fada (patch) | |
tree | 89f5fcc5a083cd1dbab03f90daf22f40ad794d2a | |
parent | f77c9fb707d12f5354a399055b6db5ebd5bc5d1f (diff) |
More comprehensively enforcing Tensor MaxDimensions limit.
Change: 119423048
-rw-r--r-- | tensorflow/core/framework/tensor_shape.cc | 9 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor_shape.h | 1 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor_shape_test.cc | 16 |
3 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc index 534b73575b..ae7c34bd93 100644 --- a/tensorflow/core/framework/tensor_shape.cc +++ b/tensorflow/core/framework/tensor_shape.cc @@ -44,6 +44,7 @@ void TensorShape::CheckDimsAtLeast(int NDIMS) const { bool TensorShape::IsValid(const TensorShapeProto& proto) { int64 num_elements = 1; + if (proto.dim().size() > MaxDimensions()) return false; for (const auto& d : proto.dim()) { if (d.size() < 0) return false; num_elements *= d.size(); @@ -54,6 +55,10 @@ bool TensorShape::IsValid(const TensorShapeProto& proto) { Status TensorShape::IsValidShape(const TensorShapeProto& proto) { int64 num_elements = 1; + if (proto.dim().size() > MaxDimensions()) { + return errors::InvalidArgument("Shape ", DebugString(proto), + " has too many dimensions"); + } for (const auto& d : proto.dim()) { if (d.size() < 0) { return errors::InvalidArgument("Shape ", DebugString(proto), @@ -214,6 +219,7 @@ void TensorShape::InsertDim(int d, int64 size) { CHECK_GE(d, 0); CHECK_LE(d, dims()); CHECK_GE(size, 0); + CHECK_LT(dims(), MaxDimensions()); gtl::InlinedVector<int64, 8> vals; AppendTo(*this, &vals); vals.insert(vals.begin() + d, size); @@ -341,6 +347,9 @@ bool TensorShapeUtils::StartsWith(const TensorShape& shape, template <typename T> static inline Status MakeShapeHelper(const T* dims, int n, TensorShape* out) { *out = TensorShape(); + if (n > TensorShape::MaxDimensions()) { + return errors::InvalidArgument("Too many dimensions"); + } for (int i = 0; i < n; ++i) { const T dim = internal::SubtleMustCopy(dims[i]); if (dim >= 0) { diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index b6f20c73e0..84947e308a 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -280,6 +280,7 @@ template <int NDIMS> Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizesWithPadding() const { CheckDimsAtLeast(NDIMS); + static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions"); Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes; for (int d = 0; d < dims(); d++) { dsizes[d] = dim_size(d); diff --git a/tensorflow/core/framework/tensor_shape_test.cc b/tensorflow/core/framework/tensor_shape_test.cc index f47d2f9ac3..5eeaeb61da 100644 --- a/tensorflow/core/framework/tensor_shape_test.cc +++ b/tensorflow/core/framework/tensor_shape_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -87,6 +88,21 @@ TEST(TensorShapeTest, InvalidShapeProto) { EXPECT_FALSE(TensorShape::IsValid(proto)); } +TEST(TensorShapeTest, TooManyDimsProto) { + TensorShapeProto proto; + // Deliberate redundancy to ensure that both paths work. + EXPECT_TRUE(TensorShape::IsValid(proto)); + TF_EXPECT_OK(TensorShape::IsValidShape(proto)); + for (int i = 0; i < TensorShape::MaxDimensions(); i++) { + proto.add_dim()->set_size(1); + } + EXPECT_TRUE(TensorShape::IsValid(proto)); + TF_EXPECT_OK(TensorShape::IsValidShape(proto)); + proto.add_dim()->set_size(1); + EXPECT_FALSE(TensorShape::IsValid(proto)); + EXPECT_FALSE(TensorShape::IsValidShape(proto).ok()); +} + TEST(TensorShapeTest, SetDimForEmptyTensor) { TensorShape s({10, 5, 20}); EXPECT_EQ(1000, s.num_elements()); |