aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/TestUtil.java')
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/TestUtil.java8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index 6a3a16c2e1..e3415a696d 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -48,6 +48,14 @@ public class TestUtil {
.output(0);
}
+ public static Operation split(Graph g, String name, int[] values, int numSplit) {
+ return g.opBuilder("Split", name)
+ .addInput(constant(g, "split_dim", 0))
+ .addInput(constant(g, "values", values))
+ .setAttr("num_split", numSplit)
+ .build();
+ }
+
public static void transpose_A_times_X(Graph g, int[][] a) {
matmul(g, "Y", constant(g, "A", a), placeholder(g, "X", DataType.INT32), true, false);
}