aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/java/org/tensorflow/Operation.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/Operation.java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Operation.java26
1 files changed, 25 insertions, 1 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Operation.java b/tensorflow/java/src/main/java/org/tensorflow/Operation.java
index 43dbaf125c..e7de603409 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Operation.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Operation.java
@@ -79,7 +79,7 @@ public final class Operation {
*
* @param name identifier of the list of tensors (of which there may
* be many) produced by this operation.
- * @returns the size of the list of Tensors produced by this named output.
+ * @return the size of the list of Tensors produced by this named output.
* @throws IllegalArgumentException if this operation has no output
* with the provided name.
*/
@@ -97,6 +97,28 @@ public final class Operation {
return new Output(this, idx);
}
+ /**
+ * Returns the size of the given inputs list of Tensors for this operation.
+ *
+ * <p>An Operation has multiple named inputs, each of which contains either
+ * a single tensor or a list of tensors. This method returns the size of
+ * the list of tensors for a specific named input of the operation.
+ *
+ * @param name identifier of the list of tensors (of which there may
+ * be many) inputs to this operation.
+ * @returns the size of the list of Tensors produced by this named input.
+ * @throws IllegalArgumentException if this operation has no input
+ * with the provided name.
+ */
+ public int inputListLength(final String name) {
+ Graph.Reference r = graph.ref();
+ try {
+ return inputListLength(unsafeNativeHandle, name);
+ } finally {
+ r.close();
+ }
+ }
+
long getUnsafeNativeHandle() {
return unsafeNativeHandle;
}
@@ -132,6 +154,8 @@ public final class Operation {
private static native int outputListLength(long handle, String name);
+ private static native int inputListLength(long handle, String name);
+
private static native long[] shape(long graphHandle, long opHandle, int output);
private static native int dtype(long graphHandle, long opHandle, int output);