diff options
author | 2018-03-26 11:38:15 -0700 | |
---|---|---|
committer | 2018-03-26 11:41:03 -0700 | |
commit | 73937a7096908a9ae01dd7da2d76932a7fed194b (patch) | |
tree | 197343b4c9a4bec483a120203a320c5a675cd048 /tensorflow/core/framework/shape_inference.cc | |
parent | a7588a70a5de8ece6920f4eb8b877104ede898f7 (diff) |
Made the NumElements function more accurate
PiperOrigin-RevId: 190497916
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 641681973a..54ecaa5dd4 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -298,13 +298,23 @@ bool InferenceContext::FullyDefined(ShapeHandle s) { DimensionHandle InferenceContext::NumElements(ShapeHandle s) { const auto rank = Rank(s); if (rank == kUnknownRank) return UnknownDim(); + bool found_unknown = false; int64 size = 1; for (int i = 0; i < rank; ++i) { int64 dim_val = Value(Dim(s, i)); - if (dim_val == kUnknownDim) return UnknownDim(); - size *= dim_val; + if (dim_val == kUnknownDim) { + found_unknown = true; + } else if (dim_val == 0) { + return MakeDim(0); + } else { + size *= dim_val; + } + } + if (found_unknown) { + return UnknownDim(); + } else { + return MakeDim(size); } - return MakeDim(size); } string InferenceContext::DebugString(ShapeHandle s) { |