aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/OperationTest.java')
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/OperationTest.java112
1 files changed, 112 insertions, 0 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java
index 74fdcf484e..27afc046ac 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java
@@ -16,8 +16,14 @@ limitations under the License.
package org.tensorflow;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -46,12 +52,107 @@ public class OperationTest {
}
@Test
+ public void operationEquality() {
+ Operation op1;
+ try (Graph g = new Graph()) {
+ op1 = TestUtil.constant(g, "op1", 1).op();
+ Operation op2 = TestUtil.constant(g, "op2", 2).op();
+ Operation op3 = new Operation(g, op1.getUnsafeNativeHandle());
+ Operation op4 = g.operation("op1");
+ assertEquals(op1, op1);
+ assertNotEquals(op1, op2);
+ assertEquals(op1, op3);
+ assertEquals(op1.hashCode(), op3.hashCode());
+ assertEquals(op1, op4);
+ assertEquals(op1.hashCode(), op4.hashCode());
+ assertEquals(op3, op4);
+ assertNotEquals(op2, op3);
+ assertNotEquals(op2, op4);
+ }
+ try (Graph g = new Graph()) {
+ Operation newOp1 = TestUtil.constant(g, "op1", 1).op();
+ assertNotEquals(op1, newOp1);
+ }
+ }
+
+ @Test
+ public void operationCollection() {
+ try (Graph g = new Graph()) {
+ Operation op1 = TestUtil.constant(g, "op1", 1).op();
+ Operation op2 = TestUtil.constant(g, "op2", 2).op();
+ Operation op3 = new Operation(g, op1.getUnsafeNativeHandle());
+ Operation op4 = g.operation("op1");
+ Set<Operation> ops = new HashSet<>();
+ ops.addAll(Arrays.asList(op1, op2, op3, op4));
+ assertEquals(2, ops.size());
+ assertTrue(ops.contains(op1));
+ assertTrue(ops.contains(op2));
+ assertTrue(ops.contains(op3));
+ assertTrue(ops.contains(op4));
+ }
+ }
+
+ @Test
+ public void operationToString() {
+ try (Graph g = new Graph()) {
+ Operation op = TestUtil.constant(g, "c", new int[] {1}).op();
+ assertNotNull(op.toString());
+ }
+ }
+
+ @Test
+ public void outputEquality() {
+ try (Graph g = new Graph()) {
+ Output output = TestUtil.constant(g, "c", 1);
+ Output output1 = output.op().output(0);
+ Output output2 = g.operation("c").output(0);
+ assertEquals(output, output1);
+ assertEquals(output.hashCode(), output1.hashCode());
+ assertEquals(output, output2);
+ assertEquals(output.hashCode(), output2.hashCode());
+ }
+ }
+
+ @Test
+ public void outputCollection() {
+ try (Graph g = new Graph()) {
+ Output output = TestUtil.constant(g, "c", 1);
+ Output output1 = output.op().output(0);
+ Output output2 = g.operation("c").output(0);
+ Set<Output> ops = new HashSet<>();
+ ops.addAll(Arrays.asList(output, output1, output2));
+ assertEquals(1, ops.size());
+ assertTrue(ops.contains(output));
+ assertTrue(ops.contains(output1));
+ assertTrue(ops.contains(output2));
+ }
+ }
+
+ @Test
+ public void outputToString() {
+ try (Graph g = new Graph()) {
+ Output output = TestUtil.constant(g, "c", new int[] {1});
+ assertNotNull(output.toString());
+ }
+ }
+
+ @Test
public void outputListLength() {
assertEquals(1, split(new int[] {0, 1}, 1));
assertEquals(2, split(new int[] {0, 1}, 2));
assertEquals(3, split(new int[] {0, 1, 2}, 3));
}
+ @Test
+ public void inputListLength() {
+ assertEquals(1, splitWithInputList(new int[] {0, 1}, 1, "split_dim"));
+ try {
+ splitWithInputList(new int[] {0, 1}, 2, "inputs");
+ } catch (IllegalArgumentException iae) {
+ // expected
+ }
+ }
+
private static int split(int[] values, int num_split) {
try (Graph g = new Graph()) {
return g.opBuilder("Split", "Split")
@@ -62,4 +163,15 @@ public class OperationTest {
.outputListLength("output");
}
}
+
+ private static int splitWithInputList(int[] values, int num_split, String name) {
+ try (Graph g = new Graph()) {
+ return g.opBuilder("Split", "Split")
+ .addInput(TestUtil.constant(g, "split_dim", 0))
+ .addInput(TestUtil.constant(g, "values", values))
+ .setAttr("num_split", num_split)
+ .build()
+ .inputListLength(name);
+ }
+ }
}