aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java56
1 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
new file mode 100644
index 0000000000..7dd35bb21f
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
@@ -0,0 +1,56 @@
+package org.tensorflow.op.core;
+
+import java.nio.ByteBuffer;
+
+import org.tensorflow.DataType;
+import org.tensorflow.Operand;
+import org.tensorflow.Output;
+import org.tensorflow.Shape;
+import org.tensorflow.op.Op;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.annotation.Operator;
+
+/**
+ * An operator creating a constant initialized with zeros w.r.t its type and shape.
+ *
+ * @param <T> constant type
+ */
+@Operator
+public class Zeros<T> implements Op, Operand<T> {
+
+ /**
+ * Factory method for this operator
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param type the tensor datatype.
+ * @param shape the tensor shape.
+ * @return a constant initialized with zeros
+ * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros.
+ */
+ public static <T> Zeros<T> create(Scope scope, Class<T> type, Shape shape) {
+ int numElements = (int) shape.numElements();
+ if (numElements < 0) {
+ throw new IllegalArgumentException("Only shapes with known dimension sizes can be used with zeroed constants");
+ }
+ int sizeInBytes = DataType.fromClass(type).sizeInBytes();
+ if (sizeInBytes < 0) {
+ throw new IllegalArgumentException(type.getSimpleName() + " constants cannot be initialized with zeros");
+ }
+ return new Zeros<T>(Constant.create(scope, type, shape, ByteBuffer.allocate(numElements * sizeInBytes)));
+ }
+
+ @Override
+ public Output<T> asOutput() {
+ return constant.asOutput();
+ }
+
+ public Constant<T> constant() {
+ return constant;
+ }
+
+ private final Constant<T> constant;
+
+ private Zeros(Constant<T> constant) {
+ this.constant = constant;
+ }
+}