diff options
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/TestUtil.java')
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/TestUtil.java | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java index 6a3a16c2e1..e3415a696d 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java @@ -48,6 +48,14 @@ public class TestUtil { .output(0); } + public static Operation split(Graph g, String name, int[] values, int numSplit) { + return g.opBuilder("Split", name) + .addInput(constant(g, "split_dim", 0)) + .addInput(constant(g, "values", values)) + .setAttr("num_split", numSplit) + .build(); + } + public static void transpose_A_times_X(Graph g, int[][] a) { matmul(g, "Y", constant(g, "A", a), placeholder(g, "X", DataType.INT32), true, false); } |