aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java')
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java b/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java
index 3b027700c5..92cc3bd60e 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java
@@ -16,6 +16,7 @@ limitations under the License.
package org.tensorflow;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -77,4 +78,29 @@ public class ShapeTest {
assertEquals(5, n.shape().size(1));
}
}
+
+ @Test
+ public void equalsWorksCorrectly() {
+ assertEquals(Shape.scalar(), Shape.scalar());
+ assertEquals(Shape.make(1, 2, 3), Shape.make(1, 2, 3));
+
+ assertNotEquals(Shape.make(1,2), null);
+ assertNotEquals(Shape.make(1,2), new Object());
+ assertNotEquals(Shape.make(1, 2, 3), Shape.make(1, 2, 4));
+
+
+ assertNotEquals(Shape.unknown(), Shape.unknown());
+ assertNotEquals(Shape.make(-1), Shape.make(-1));
+ assertNotEquals(Shape.make(1, -1, 3), Shape.make(1, -1, 3));
+ }
+
+ @Test
+ public void hashCodeIsAsExpected() {
+ assertEquals(Shape.make(1, 2, 3, 4).hashCode(), Shape.make(1, 2, 3, 4).hashCode());
+ assertEquals(Shape.scalar().hashCode(), Shape.scalar().hashCode());
+ assertEquals(Shape.unknown().hashCode(), Shape.unknown().hashCode());
+
+ assertNotEquals(Shape.make(1, 2).hashCode(), Shape.make(1, 3).hashCode());
+ }
}
+