diff options
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java')
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java b/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java index 3b027700c5..92cc3bd60e 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java @@ -16,6 +16,7 @@ limitations under the License. package org.tensorflow; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; import org.junit.Test; import org.junit.runner.RunWith; @@ -77,4 +78,29 @@ public class ShapeTest { assertEquals(5, n.shape().size(1)); } } + + @Test + public void equalsWorksCorrectly() { + assertEquals(Shape.scalar(), Shape.scalar()); + assertEquals(Shape.make(1, 2, 3), Shape.make(1, 2, 3)); + + assertNotEquals(Shape.make(1,2), null); + assertNotEquals(Shape.make(1,2), new Object()); + assertNotEquals(Shape.make(1, 2, 3), Shape.make(1, 2, 4)); + + + assertNotEquals(Shape.unknown(), Shape.unknown()); + assertNotEquals(Shape.make(-1), Shape.make(-1)); + assertNotEquals(Shape.make(1, -1, 3), Shape.make(1, -1, 3)); + } + + @Test + public void hashCodeIsAsExpected() { + assertEquals(Shape.make(1, 2, 3, 4).hashCode(), Shape.make(1, 2, 3, 4).hashCode()); + assertEquals(Shape.scalar().hashCode(), Shape.scalar().hashCode()); + assertEquals(Shape.unknown().hashCode(), Shape.unknown().hashCode()); + + assertNotEquals(Shape.make(1, 2).hashCode(), Shape.make(1, 3).hashCode()); + } } + |