aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java46
1 files changed, 22 insertions, 24 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
index 5bba594e17..cc46ce3c5b 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
@@ -5,13 +5,17 @@ import java.nio.ByteBuffer;
import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.Output;
-import org.tensorflow.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Operator;
/**
- * An operator creating a constant initialized with zeros w.r.t its type and shape.
+ * An operator creating a constant initialized with zeros of the shape given by `dims`.
+ *
+ * <p>For example, the following expression
+ * <pre>{@code ops.zeros(ops.constant(new long[]{2, 2}), Float.class)</pre>
+ * is the equivalent of
+ * <pre>{@code ops.fill(ops.constant(new long[]{2, 2}), ops.constant(0.0f))</pre>
*
* @param <T> constant type
*/
@@ -19,38 +23,32 @@ import org.tensorflow.op.annotation.Operator;
public class Zeros<T> implements Op, Operand<T> {
/**
- * Factory method for this operator
+ * Creates a zeroed tensor given its type and shape.
*
- * @param scope is a scope used to add the underlying operation.
- * @param type the tensor datatype.
- * @param shape the tensor shape.
- * @return a constant initialized with zeros
+ * @param scope is a scope used to add the underlying operation
+ * @param dims a 1-D operand that represents the shape of the output tensor
+ * @param type the output tensor datatype
+ * @return a constant tensor initialized with zeros
* @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros.
*/
- public static <T> Zeros<T> create(Scope scope, Class<T> type, Shape shape) {
- int numElements = shape.numElements();
- if (numElements < 0) {
- throw new IllegalArgumentException("Only shapes with known dimension sizes can be used with zeroed constants");
+ public static <T, U extends Number> Zeros<T> create(Scope scope, Operand<U> dims, Class<T> type) {
+ Scope childScope = scope.withSubScope("Zeros"); // If scope had an op name set, it will prevail on "Zeros"
+ int zeroSize = DataType.fromClass(type).byteSize();
+ if (zeroSize < 0) {
+ throw new IllegalArgumentException(type.getSimpleName() + " tensors cannot be initialized with zeros");
}
- int sizeInBytes = DataType.fromClass(type).sizeInBytes();
- if (sizeInBytes < 0) {
- throw new IllegalArgumentException(type.getSimpleName() + " constants cannot be initialized with zeros");
- }
- return new Zeros<T>(Constant.create(scope, type, shape, ByteBuffer.allocate(numElements * sizeInBytes)));
+ Constant<T> zero = Constant.create(childScope.withName("Zero"), type, new long[]{}, ByteBuffer.allocate(zeroSize));
+ return new Zeros<T>(Fill.create(childScope, dims, zero));
}
@Override
public Output<T> asOutput() {
- return constant.asOutput();
- }
-
- public Constant<T> constant() {
- return constant;
+ return fill.asOutput();
}
- private final Constant<T> constant;
+ private final Fill<T> fill;
- private Zeros(Constant<T> constant) {
- this.constant = constant;
+ private Zeros(Fill<T> fill) {
+ this.fill = fill;
}
}