aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2018-03-26 11:38:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-26 11:41:03 -0700
commit73937a7096908a9ae01dd7da2d76932a7fed194b (patch)
tree197343b4c9a4bec483a120203a320c5a675cd048 /tensorflow/core/framework/shape_inference.cc
parenta7588a70a5de8ece6920f4eb8b877104ede898f7 (diff)
Made the NumElements function more accurate
PiperOrigin-RevId: 190497916
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r--tensorflow/core/framework/shape_inference.cc16
1 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 641681973a..54ecaa5dd4 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -298,13 +298,23 @@ bool InferenceContext::FullyDefined(ShapeHandle s) {
DimensionHandle InferenceContext::NumElements(ShapeHandle s) {
const auto rank = Rank(s);
if (rank == kUnknownRank) return UnknownDim();
+ bool found_unknown = false;
int64 size = 1;
for (int i = 0; i < rank; ++i) {
int64 dim_val = Value(Dim(s, i));
- if (dim_val == kUnknownDim) return UnknownDim();
- size *= dim_val;
+ if (dim_val == kUnknownDim) {
+ found_unknown = true;
+ } else if (dim_val == 0) {
+ return MakeDim(0);
+ } else {
+ size *= dim_val;
+ }
+ }
+ if (found_unknown) {
+ return UnknownDim();
+ } else {
+ return MakeDim(size);
}
- return MakeDim(size);
}
string InferenceContext::DebugString(ShapeHandle s) {