blob: 5bba594e17c848eb004c1a499d2916a4e499df26 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
|
package org.tensorflow.op.core;
import java.nio.ByteBuffer;
import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Operator;
/**
 * An operator creating a constant initialized with zeros w.r.t its type and shape.
 *
 * @param <T> constant type
 */
@Operator
public class Zeros<T> implements Op, Operand<T> {

  /**
   * Factory method for this operator
   *
   * @param scope is a scope used to add the underlying operation.
   * @param type the tensor datatype.
   * @param shape the tensor shape.
   * @return a constant initialized with zeros
   * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros,
   *     including when the total byte size of the tensor does not fit in an {@code int}.
   */
  public static <T> Zeros<T> create(Scope scope, Class<T> type, Shape shape) {
    int numElements = shape.numElements();
    if (numElements < 0) {
      // A negative element count is rejected: the buffer size can only be computed for a
      // shape whose dimension sizes are all known.
      throw new IllegalArgumentException("Only shapes with known dimension sizes can be used with zeroed constants");
    }
    int sizeInBytes = DataType.fromClass(type).sizeInBytes();
    if (sizeInBytes < 0) {
      // A negative per-element size means this datatype has no fixed-size zero representation.
      throw new IllegalArgumentException(type.getSimpleName() + " constants cannot be initialized with zeros");
    }
    int bufferSize;
    try {
      // Plain int multiplication would silently wrap for large shapes and yield a wrong-sized
      // (or negative) buffer capacity; fail loudly instead.
      bufferSize = Math.multiplyExact(numElements, sizeInBytes);
    } catch (ArithmeticException e) {
      throw new IllegalArgumentException(
          "Cannot initialize "
              + numElements
              + " elements of "
              + sizeInBytes
              + " bytes each with zeros: total size exceeds Integer.MAX_VALUE bytes",
          e);
    }
    // ByteBuffer.allocate initializes every byte of the new buffer to zero, which supplies the
    // zero values of the constant.
    return new Zeros<T>(Constant.create(scope, type, shape, ByteBuffer.allocate(bufferSize)));
  }

  /** Returns the output of the underlying zero-filled constant. */
  @Override
  public Output<T> asOutput() {
    return constant.asOutput();
  }

  /**
   * Returns the constant operation wrapped by this operator.
   *
   * @return the underlying zero-filled constant
   */
  public Constant<T> constant() {
    return constant;
  }

  // The constant holding the zero values; all Operand behavior delegates to it.
  private final Constant<T> constant;

  private Zeros(Constant<T> constant) {
    this.constant = constant;
  }
}
|