diff options
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java | 46 |
1 files changed, 22 insertions, 24 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java index 5bba594e17..cc46ce3c5b 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java @@ -5,13 +5,17 @@ import java.nio.ByteBuffer; import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Operator; /** - * An operator creating a constant initialized with zeros w.r.t its type and shape. + * An operator creating a constant initialized with zeros of the shape given by `dims`. + * + * <p>For example, the following expression + * <pre>{@code ops.zeros(ops.constant(new long[]{2, 2}), Float.class)</pre> + * is the equivalent of + * <pre>{@code ops.fill(ops.constant(new long[]{2, 2}), ops.constant(0.0f))</pre> * * @param <T> constant type */ @@ -19,38 +23,32 @@ import org.tensorflow.op.annotation.Operator; public class Zeros<T> implements Op, Operand<T> { /** - * Factory method for this operator + * Creates a zeroed tensor given its type and shape. * - * @param scope is a scope used to add the underlying operation. - * @param type the tensor datatype. - * @param shape the tensor shape. - * @return a constant initialized with zeros + * @param scope is a scope used to add the underlying operation + * @param dims a 1-D operand that represents the shape of the output tensor + * @param type the output tensor datatype + * @return a constant tensor initialized with zeros * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros. */ - public static <T> Zeros<T> create(Scope scope, Class<T> type, Shape shape) { - int numElements = shape.numElements(); - if (numElements < 0) { - throw new IllegalArgumentException("Only shapes with known dimension sizes can be used with zeroed constants"); + public static <T, U extends Number> Zeros<T> create(Scope scope, Operand<U> dims, Class<T> type) { + Scope childScope = scope.withSubScope("Zeros"); // If scope had an op name set, it will prevail on "Zeros" + int zeroSize = DataType.fromClass(type).byteSize(); + if (zeroSize < 0) { + throw new IllegalArgumentException(type.getSimpleName() + " tensors cannot be initialized with zeros"); } - int sizeInBytes = DataType.fromClass(type).sizeInBytes(); - if (sizeInBytes < 0) { - throw new IllegalArgumentException(type.getSimpleName() + " constants cannot be initialized with zeros"); - } - return new Zeros<T>(Constant.create(scope, type, shape, ByteBuffer.allocate(numElements * sizeInBytes))); + Constant<T> zero = Constant.create(childScope.withName("Zero"), type, new long[]{}, ByteBuffer.allocate(zeroSize)); + return new Zeros<T>(Fill.create(childScope, dims, zero)); } @Override public Output<T> asOutput() { - return constant.asOutput(); - } - - public Constant<T> constant() { - return constant; + return fill.asOutput(); } - private final Constant<T> constant; + private final Fill<T> fill; - private Zeros(Constant<T> constant) { - this.constant = constant; + private Zeros(Fill<T> fill) { + this.fill = fill; } } |