diff options
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/Operation.java')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/Operation.java | 26 |
1 files changed, 25 insertions, 1 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Operation.java b/tensorflow/java/src/main/java/org/tensorflow/Operation.java index 43dbaf125c..e7de603409 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Operation.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Operation.java @@ -79,7 +79,7 @@ public final class Operation { * * @param name identifier of the list of tensors (of which there may * be many) produced by this operation. - * @returns the size of the list of Tensors produced by this named output. + * @return the size of the list of Tensors produced by this named output. * @throws IllegalArgumentException if this operation has no output * with the provided name. */ @@ -97,6 +97,28 @@ public final class Operation { return new Output(this, idx); } + /** + * Returns the size of the given inputs list of Tensors for this operation. + * + * <p>An Operation has multiple named inputs, each of which contains either + * a single tensor or a list of tensors. This method returns the size of + * the list of tensors for a specific named input of the operation. + * + * @param name identifier of the list of tensors (of which there may + * be many) inputs to this operation. + * @returns the size of the list of Tensors produced by this named input. + * @throws IllegalArgumentException if this operation has no input + * with the provided name. + */ + public int inputListLength(final String name) { + Graph.Reference r = graph.ref(); + try { + return inputListLength(unsafeNativeHandle, name); + } finally { + r.close(); + } + } + long getUnsafeNativeHandle() { return unsafeNativeHandle; } @@ -132,6 +154,8 @@ public final class Operation { private static native int outputListLength(long handle, String name); + private static native int inputListLength(long handle, String name); + private static native long[] shape(long graphHandle, long opHandle, int output); private static native int dtype(long graphHandle, long opHandle, int output); |