diff options
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/Shape.java')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/Shape.java | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Shape.java b/tensorflow/java/src/main/java/org/tensorflow/Shape.java index d533c3d480..1662a49cb7 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Shape.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Shape.java @@ -77,6 +77,28 @@ public final class Shape { return shape[i]; } + /** + * The total number of elements found in a tensor of this shape. + * + * <p>If the size of some dimensions is unknown, the total number of elements cannot be calculated and -1 is returned. + * + * @return the number of elements or -1 if size of some dimension are unknown + */ + public int numElements() { + if (shape == null) { + return -1; + } + long total = 1; + for (int i = 0; i < shape.length; ++i) { + long size = size(i); + if (size < 0) { + return -1; + } + total *= size; + } + return total; + } + @Override public int hashCode() { return Arrays.hashCode(shape); |