diff options
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/Graph.java')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/Graph.java | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java index c08fa9b145..58ad3ab193 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java @@ -15,6 +15,8 @@ limitations under the License. package org.tensorflow; +import java.util.Iterator; + /** * A data flow graph representing a TensorFlow computation. * @@ -77,6 +79,16 @@ public final class Graph implements AutoCloseable { } /** + * Iterator over all the {@link Operation}s in the graph. + * + * The order of iteration is unspecified. Consumers of the iterator will received no notification + * should the underlying graph change during iteration. + */ + public Iterator<Operation> operations() { + return new OperationIterator(this); + } + + /** * Returns a builder to add {@link Operation}s to the Graph. * * @param type of the Operation (i.e., identifies the computation to be performed) @@ -179,12 +191,63 @@ public final class Graph implements AutoCloseable { return new Reference(); } + private static final class OperationIterator implements Iterator<Operation> { + + OperationIterator(Graph g) { + this.graph = g; + this.operation = null; + this.position = 0; + this.advance(); + } + + private final void advance() { + Graph.Reference reference = this.graph.ref(); + + this.operation = null; + + try { + long[] nativeReturn = nextOperation(reference.nativeHandle(), this.position); + + if ((nativeReturn != null) && (nativeReturn[0] != 0)) { + this.operation = new Operation(this.graph, nativeReturn[0]); + this.position = (int) nativeReturn[1]; + } + } finally { + reference.close(); + } + } + + @Override + public boolean hasNext() { + return (this.operation != null); + } + + @Override + public Operation next() { + Operation rhett = this.operation; + this.advance(); + return rhett; + } + + @Override + public void remove() { + throw new UnsupportedOperationException("remove() is unsupported."); + } + + private final Graph graph; + private Operation operation; + private int position; + } + private static native long allocate(); private static native void delete(long handle); private static native long operation(long handle, String name); + // This method returns the Operation native handle at index 0 and the new value for pos at index 1 (see TF_GraphNextOperation) + private static native long[] nextOperation(long handle, int position); + private static native void importGraphDef(long handle, byte[] graphDef, String prefix) throws IllegalArgumentException; |