aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r--tensorflow/core/framework/shape_inference.cc22
1 files changed, 16 insertions, 6 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 449d8f55f5..a990dc2f04 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -239,8 +239,11 @@ string InferenceContext::DebugString() const {
ProtoDebugString(node_def_));
}
-Status InferenceContext::WithRank(ShapeHandle shape, int32 rank,
+Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
ShapeHandle* out) {
+ if (rank > kint32max) {
+ return errors::InvalidArgument("Rank cannot exceed kint32max");
+ }
const int32 existing = Rank(shape);
if (existing == rank) {
*out = shape;
@@ -261,8 +264,11 @@ Status InferenceContext::WithRank(ShapeHandle shape, int32 rank,
existing);
}
-Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int32 rank,
+Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank,
ShapeHandle* out) {
+ if (rank > kint32max) {
+ return errors::InvalidArgument("Rank cannot exceed kint32max");
+ }
const int32 existing = Rank(shape);
if (existing >= rank) {
*out = shape;
@@ -276,8 +282,11 @@ Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int32 rank,
" but is rank ", existing);
}
-Status InferenceContext::WithRankAtMost(ShapeHandle shape, int32 rank,
+Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank,
ShapeHandle* out) {
+ if (rank > kint32max) {
+ return errors::InvalidArgument("Rank cannot exceed kint32max");
+ }
const int32 existing = Rank(shape);
if (existing == kUnknownRank) {
return ReturnUnknownShape(out);
@@ -470,12 +479,12 @@ Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
return ReturnCreatedShape(dims, out);
}
-Status InferenceContext::ReplaceDim(ShapeHandle s, int dim_index_in,
+Status InferenceContext::ReplaceDim(ShapeHandle s, int64 dim_index_in,
DimensionHandle new_dim, ShapeHandle* out) {
if (!RankKnown(s)) {
return ReturnUnknownShape(out);
}
- int dim_index = dim_index_in;
+ int64 dim_index = dim_index_in;
if (dim_index < 0) {
dim_index = s->dims_.size() + dim_index;
}
@@ -510,7 +519,8 @@ ShapeHandle InferenceContext::UnknownShape() {
return shape_manager_.UnknownShape();
}
-ShapeHandle InferenceContext::UnknownShapeOfRank(int32 rank) {
+ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) {
+ CHECK_LE(rank, kint32max) << "rank must be less than kint32max";
std::vector<DimensionHandle> dims(rank);
for (int32 i = 0; i < rank; ++i) {
dims[i] = UnknownDim();