1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
|
package org.tensorflow.op.core;
import java.nio.ByteBuffer;
import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Operator;
/**
 * An operator creating a tensor of the shape given by {@code dims}, with every element set to zero.
 *
 * <p>For example, the following expression
 * <pre>{@code ops.zeros(ops.constant(new long[]{2, 2}), Float.class)}</pre>
 * is the equivalent of
 * <pre>{@code ops.fill(ops.constant(new long[]{2, 2}), ops.constant(0.0f))}</pre>
 *
 * @param <T> constant type
 */
@Operator
public final class Zeros<T> implements Op, Operand<T> {

  /** The underlying {@code Fill} operation whose output this operator exposes. */
  private final Fill<T> fill;

  /**
   * Creates a zeroed tensor given its type and shape.
   *
   * <p>A scalar zero constant named "Zero" is created under a "Zeros" sub-scope, and a
   * {@code Fill} operation in that same sub-scope expands it to the shape given by {@code dims}.
   *
   * @param scope is a scope used to add the underlying operation
   * @param dims a 1-D operand that represents the shape of the output tensor
   * @param type the output tensor datatype
   * @param <T> the element type of the output tensor
   * @param <U> the numeric element type of {@code dims}
   * @return a constant tensor initialized with zeros
   * @throws IllegalArgumentException if the data type of {@code type} reports a negative byte size
   *     (i.e. it has no fixed-size zero value) and so cannot be initialized with zeros
   */
  public static <T, U extends Number> Zeros<T> create(Scope scope, Operand<U> dims, Class<T> type) {
    // If scope already has an op name set, it takes precedence over "Zeros".
    Scope childScope = scope.withSubScope("Zeros");
    int zeroSize = DataType.fromClass(type).byteSize();
    if (zeroSize < 0) {
      throw new IllegalArgumentException(
          type.getSimpleName() + " tensors cannot be initialized with zeros");
    }
    // ByteBuffer.allocate zero-fills its backing array, so this buffer encodes a single scalar 0
    // of the requested type regardless of byte order.
    Constant<T> zero =
        Constant.create(
            childScope.withName("Zero"), type, new long[]{}, ByteBuffer.allocate(zeroSize));
    return new Zeros<T>(Fill.create(childScope, dims, zero));
  }

  @Override
  public Output<T> asOutput() {
    return fill.asOutput();
  }

  private Zeros(Fill<T> fill) {
    this.fill = fill;
  }
}
|