diff options
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 449d8f55f5..a990dc2f04 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -239,8 +239,11 @@ string InferenceContext::DebugString() const { ProtoDebugString(node_def_)); } -Status InferenceContext::WithRank(ShapeHandle shape, int32 rank, +Status InferenceContext::WithRank(ShapeHandle shape, int64 rank, ShapeHandle* out) { + if (rank > kint32max) { + return errors::InvalidArgument("Rank cannot exceed kint32max"); + } const int32 existing = Rank(shape); if (existing == rank) { *out = shape; @@ -261,8 +264,11 @@ Status InferenceContext::WithRank(ShapeHandle shape, int32 rank, existing); } -Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int32 rank, +Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank, ShapeHandle* out) { + if (rank > kint32max) { + return errors::InvalidArgument("Rank cannot exceed kint32max"); + } const int32 existing = Rank(shape); if (existing >= rank) { *out = shape; @@ -276,8 +282,11 @@ Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int32 rank, " but is rank ", existing); } -Status InferenceContext::WithRankAtMost(ShapeHandle shape, int32 rank, +Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank, ShapeHandle* out) { + if (rank > kint32max) { + return errors::InvalidArgument("Rank cannot exceed kint32max"); + } const int32 existing = Rank(shape); if (existing == kUnknownRank) { return ReturnUnknownShape(out); @@ -470,12 +479,12 @@ Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2, return ReturnCreatedShape(dims, out); } -Status InferenceContext::ReplaceDim(ShapeHandle s, int dim_index_in, +Status InferenceContext::ReplaceDim(ShapeHandle s, int64 dim_index_in, DimensionHandle new_dim, ShapeHandle* out) { if (!RankKnown(s)) { return ReturnUnknownShape(out); } - int dim_index = dim_index_in; + int64 dim_index = dim_index_in; if (dim_index < 0) { dim_index = s->dims_.size() + dim_index; } @@ -510,7 +519,8 @@ ShapeHandle InferenceContext::UnknownShape() { return shape_manager_.UnknownShape(); } -ShapeHandle InferenceContext::UnknownShapeOfRank(int32 rank) { +ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) { + CHECK_LE(rank, kint32max) << "rank must be less than kint32max"; std::vector<DimensionHandle> dims(rank); for (int32 i = 0; i < rank; ++i) { dims[i] = UnknownDim(); |