aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/java/org/tensorflow/Shape.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/Shape.java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Shape.java32
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Shape.java b/tensorflow/java/src/main/java/org/tensorflow/Shape.java
index 9aa92be111..d533c3d480 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Shape.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Shape.java
@@ -77,6 +77,24 @@ public final class Shape {
return shape[i];
}
+ @Override
+ public int hashCode() {
+ return Arrays.hashCode(shape);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+
+ if (obj instanceof Shape && Arrays.equals(this.shape, ((Shape) obj).shape)) {
+ return !hasUnknownDimension();
+ }
+
+ return super.equals(obj);
+ }
+
/** Succinct description of the shape meant for debugging. */
@Override
public String toString() {
@@ -98,4 +116,18 @@ public final class Shape {
}
private long[] shape;
+
+ private boolean hasUnknownDimension() {
+ if (shape == null) {
+ return true;
+ }
+
+ for (long dimension : shape) {
+ if (dimension == -1) {
+ return true;
+ }
+ }
+
+ return false;
+ }
}