aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar David G. Andersen <dga@google.com>2016-04-08 15:51:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-08 17:02:46 -0700
commit4d9ec5ece5771a1982352574ce2cad587644fada (patch)
tree89f5fcc5a083cd1dbab03f90daf22f40ad794d2a
parentf77c9fb707d12f5354a399055b6db5ebd5bc5d1f (diff)
More comprehensively enforcing Tensor MaxDimensions limit.
Change: 119423048
-rw-r--r--tensorflow/core/framework/tensor_shape.cc9
-rw-r--r--tensorflow/core/framework/tensor_shape.h1
-rw-r--r--tensorflow/core/framework/tensor_shape_test.cc16
3 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc
index 534b73575b..ae7c34bd93 100644
--- a/tensorflow/core/framework/tensor_shape.cc
+++ b/tensorflow/core/framework/tensor_shape.cc
@@ -44,6 +44,7 @@ void TensorShape::CheckDimsAtLeast(int NDIMS) const {
bool TensorShape::IsValid(const TensorShapeProto& proto) {
int64 num_elements = 1;
+ if (proto.dim().size() > MaxDimensions()) return false;
for (const auto& d : proto.dim()) {
if (d.size() < 0) return false;
num_elements *= d.size();
@@ -54,6 +55,10 @@ bool TensorShape::IsValid(const TensorShapeProto& proto) {
Status TensorShape::IsValidShape(const TensorShapeProto& proto) {
int64 num_elements = 1;
+ if (proto.dim().size() > MaxDimensions()) {
+ return errors::InvalidArgument("Shape ", DebugString(proto),
+ " has too many dimensions");
+ }
for (const auto& d : proto.dim()) {
if (d.size() < 0) {
return errors::InvalidArgument("Shape ", DebugString(proto),
@@ -214,6 +219,7 @@ void TensorShape::InsertDim(int d, int64 size) {
CHECK_GE(d, 0);
CHECK_LE(d, dims());
CHECK_GE(size, 0);
+ CHECK_LT(dims(), MaxDimensions());
gtl::InlinedVector<int64, 8> vals;
AppendTo(*this, &vals);
vals.insert(vals.begin() + d, size);
@@ -341,6 +347,9 @@ bool TensorShapeUtils::StartsWith(const TensorShape& shape,
template <typename T>
static inline Status MakeShapeHelper(const T* dims, int n, TensorShape* out) {
*out = TensorShape();
+ if (n > TensorShape::MaxDimensions()) {
+ return errors::InvalidArgument("Too many dimensions");
+ }
for (int i = 0; i < n; ++i) {
const T dim = internal::SubtleMustCopy(dims[i]);
if (dim >= 0) {
diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h
index b6f20c73e0..84947e308a 100644
--- a/tensorflow/core/framework/tensor_shape.h
+++ b/tensorflow/core/framework/tensor_shape.h
@@ -280,6 +280,7 @@ template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizesWithPadding()
const {
CheckDimsAtLeast(NDIMS);
+ static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions");
Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
for (int d = 0; d < dims(); d++) {
dsizes[d] = dim_size(d);
diff --git a/tensorflow/core/framework/tensor_shape_test.cc b/tensorflow/core/framework/tensor_shape_test.cc
index f47d2f9ac3..5eeaeb61da 100644
--- a/tensorflow/core/framework/tensor_shape_test.cc
+++ b/tensorflow/core/framework/tensor_shape_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -87,6 +88,21 @@ TEST(TensorShapeTest, InvalidShapeProto) {
EXPECT_FALSE(TensorShape::IsValid(proto));
}
+TEST(TensorShapeTest, TooManyDimsProto) {
+ TensorShapeProto proto;
+ // Deliberate redundancy to ensure that both paths work.
+ EXPECT_TRUE(TensorShape::IsValid(proto));
+ TF_EXPECT_OK(TensorShape::IsValidShape(proto));
+ for (int i = 0; i < TensorShape::MaxDimensions(); i++) {
+ proto.add_dim()->set_size(1);
+ }
+ EXPECT_TRUE(TensorShape::IsValid(proto));
+ TF_EXPECT_OK(TensorShape::IsValidShape(proto));
+ proto.add_dim()->set_size(1);
+ EXPECT_FALSE(TensorShape::IsValid(proto));
+ EXPECT_FALSE(TensorShape::IsValidShape(proto).ok());
+}
+
TEST(TensorShapeTest, SetDimForEmptyTensor) {
TensorShape s({10, 5, 20});
EXPECT_EQ(1000, s.num_elements());