diff options
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/Operation.java')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/Operation.java | 73 |
1 files changed, 65 insertions, 8 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Operation.java b/tensorflow/java/src/main/java/org/tensorflow/Operation.java index 43dbaf125c..200350f7ae 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Operation.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Operation.java @@ -73,15 +73,14 @@ public final class Operation { /** * Returns the size of the list of Tensors produced by this operation. * - * <p>An Operation has multiple named outputs, each of which produces either - * a single tensor or a list of tensors. This method returns the size of - * the list of tensors for a specific named output of the operation. + * <p>An Operation has multiple named outputs, each of which produces either a single tensor or a + * list of tensors. This method returns the size of the list of tensors for a specific named + * output of the operation. * - * @param name identifier of the list of tensors (of which there may - * be many) produced by this operation. - * @returns the size of the list of Tensors produced by this named output. - * @throws IllegalArgumentException if this operation has no output - * with the provided name. + * @param name identifier of the list of tensors (of which there may be many) produced by this + * operation. + * @return the size of the list of Tensors produced by this named output. + * @throws IllegalArgumentException if this operation has no output with the provided name. */ public int outputListLength(final String name) { Graph.Reference r = graph.ref(); @@ -97,6 +96,61 @@ public final class Operation { return new Output(this, idx); } + @Override + public int hashCode() { + return Long.valueOf(unsafeNativeHandle).hashCode(); + } + + @Override + public boolean equals(Object o) { + if (o == this) { + return true; + } + if (!(o instanceof Operation)) { + return false; + } + Operation that = (Operation) o; + if (graph != that.graph) { + return false; + } + + // The graph object is known to be identical here, so this one + // reference is sufficient to validate the use of native pointers + // in both objects. + Graph.Reference r = graph.ref(); + try { + return unsafeNativeHandle == that.unsafeNativeHandle; + } finally { + r.close(); + } + } + + @Override + public String toString() { + return String.format("<%s '%s'>", type(), name()); + } + + /** + * Returns the size of the given inputs list of Tensors for this operation. + * + * <p>An Operation has multiple named inputs, each of which contains either a single tensor or a + * list of tensors. This method returns the size of the list of tensors for a specific named input + * of the operation. + * + * @param name identifier of the list of tensors (of which there may be many) inputs to this + * operation. + * @returns the size of the list of Tensors produced by this named input. + * @throws IllegalArgumentException if this operation has no input with the provided name. + */ + public int inputListLength(final String name) { + Graph.Reference r = graph.ref(); + try { + return inputListLength(unsafeNativeHandle, name); + } finally { + r.close(); + } + } + long getUnsafeNativeHandle() { return unsafeNativeHandle; } @@ -122,6 +176,7 @@ public final class Operation { } private final long unsafeNativeHandle; + private final Graph graph; private static native String name(long handle); @@ -132,6 +187,8 @@ public final class Operation { private static native int outputListLength(long handle, String name); + private static native int inputListLength(long handle, String name); + private static native long[] shape(long graphHandle, long opHandle, int output); private static native int dtype(long graphHandle, long opHandle, int output); |