aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/java/org/tensorflow/Shape.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/Shape.java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Shape.java22
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Shape.java b/tensorflow/java/src/main/java/org/tensorflow/Shape.java
index d533c3d480..1662a49cb7 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Shape.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Shape.java
@@ -77,6 +77,28 @@ public final class Shape {
return shape[i];
}
+ /**
+ * The total number of elements found in a tensor of this shape.
+ *
+ * <p>If the size of some dimensions is unknown, the total number of elements cannot be calculated and -1 is returned.
+ *
+ * @return the number of elements or -1 if size of some dimension are unknown
+ */
+ public int numElements() {
+ if (shape == null) {
+ return -1;
+ }
+ long total = 1;
+ for (int i = 0; i < shape.length; ++i) {
+ long size = size(i);
+ if (size < 0) {
+ return -1;
+ }
+ total *= size;
+ }
+ return total;
+ }
+
@Override
public int hashCode() {
return Arrays.hashCode(shape);