diff options
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/Shape.java')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/Shape.java | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Shape.java b/tensorflow/java/src/main/java/org/tensorflow/Shape.java index 9aa92be111..d533c3d480 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Shape.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Shape.java @@ -77,6 +77,24 @@ public final class Shape { return shape[i]; } + @Override + public int hashCode() { + return Arrays.hashCode(shape); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj instanceof Shape && Arrays.equals(this.shape, ((Shape) obj).shape)) { + return !hasUnknownDimension(); + } + + return super.equals(obj); + } + /** Succinct description of the shape meant for debugging. */ @Override public String toString() { @@ -98,4 +116,18 @@ public final class Shape { } private long[] shape; + + private boolean hasUnknownDimension() { + if (shape == null) { + return true; + } + + for (long dimension : shape) { + if (dimension == -1) { + return true; + } + } + + return false; + } } |