aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/java/org/tensorflow/Graph.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/Graph.java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Graph.java63
1 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index c08fa9b145..58ad3ab193 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -15,6 +15,8 @@ limitations under the License.
package org.tensorflow;
+import java.util.Iterator;
+
/**
* A data flow graph representing a TensorFlow computation.
*
@@ -77,6 +79,16 @@ public final class Graph implements AutoCloseable {
}
/**
+ * Iterator over all the {@link Operation}s in the graph.
+ *
+ * The order of iteration is unspecified. Consumers of the iterator will received no notification
+ * should the underlying graph change during iteration.
+ */
+ public Iterator<Operation> operations() {
+ return new OperationIterator(this);
+ }
+
+ /**
* Returns a builder to add {@link Operation}s to the Graph.
*
* @param type of the Operation (i.e., identifies the computation to be performed)
@@ -179,12 +191,63 @@ public final class Graph implements AutoCloseable {
return new Reference();
}
+ private static final class OperationIterator implements Iterator<Operation> {
+
+ OperationIterator(Graph g) {
+ this.graph = g;
+ this.operation = null;
+ this.position = 0;
+ this.advance();
+ }
+
+ private final void advance() {
+ Graph.Reference reference = this.graph.ref();
+
+ this.operation = null;
+
+ try {
+ long[] nativeReturn = nextOperation(reference.nativeHandle(), this.position);
+
+ if ((nativeReturn != null) && (nativeReturn[0] != 0)) {
+ this.operation = new Operation(this.graph, nativeReturn[0]);
+ this.position = (int) nativeReturn[1];
+ }
+ } finally {
+ reference.close();
+ }
+ }
+
+ @Override
+ public boolean hasNext() {
+ return (this.operation != null);
+ }
+
+ @Override
+ public Operation next() {
+ Operation rhett = this.operation;
+ this.advance();
+ return rhett;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException("remove() is unsupported.");
+ }
+
+ private final Graph graph;
+ private Operation operation;
+ private int position;
+ }
+
private static native long allocate();
private static native void delete(long handle);
private static native long operation(long handle, String name);
+ // This method returns the Operation native handle at index 0 and the new value for pos at index 1 (see TF_GraphNextOperation)
+ private static native long[] nextOperation(long handle, int position);
+
private static native void importGraphDef(long handle, byte[] graphDef, String prefix)
throws IllegalArgumentException;