aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
blob: cc46ce3c5bddfba9047035a301bec462f8c0a726 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
package org.tensorflow.op.core;

import java.nio.ByteBuffer;

import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Operator;

/**
 * An operator that yields a tensor of the shape described by {@code dims} in which every element
 * is zero.
 *
 * <p>For example, the following expression
 * <pre>{@code ops.zeros(ops.constant(new long[]{2, 2}), Float.class)}</pre>
 * is the equivalent of
 * <pre>{@code ops.fill(ops.constant(new long[]{2, 2}), ops.constant(0.0f))}</pre>
 *
 * @param <T> constant type
 */
@Operator
public class Zeros<T> implements Op, Operand<T> {

  /** The fill operation this operator wraps; its output is what {@link #asOutput()} exposes. */
  private final Fill<T> fill;

  private Zeros(Fill<T> fill) {
    this.fill = fill;
  }

  /**
   * Builds a tensor filled with zeros from the requested element type and shape.
   *
   * <p>A scalar zero constant of the requested type is created first and then handed to a
   * {@code Fill} operation together with {@code dims}.
   *
   * @param scope is a scope used to add the underlying operation
   * @param dims a 1-D operand that represents the shape of the output tensor
   * @param type the output tensor datatype
   * @return a constant tensor initialized with zeros
   * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros.
   */
  public static <T, U extends Number> Zeros<T> create(Scope scope, Operand<U> dims, Class<T> type) {
    // When the incoming scope already carries an op name, that name wins over "Zeros".
    Scope zerosScope = scope.withSubScope("Zeros");

    // A negative byte size is treated as "this type cannot be represented by zeroed bytes".
    int elementBytes = DataType.fromClass(type).byteSize();
    if (elementBytes < 0) {
      throw new IllegalArgumentException(type.getSimpleName() + " tensors cannot be initialized with zeros");
    }

    Scope scalarScope = zerosScope.withName("Zero");
    long[] scalarShape = new long[0];
    // ByteBuffer.allocate returns a buffer whose bytes are all zero, i.e. one zero-valued element.
    ByteBuffer zeroedBytes = ByteBuffer.allocate(elementBytes);
    Constant<T> scalarZero = Constant.create(scalarScope, type, scalarShape, zeroedBytes);

    Fill<T> filled = Fill.create(zerosScope, dims, scalarZero);
    return new Zeros<T>(filled);
  }

  /** Returns the output of the wrapped fill operation. */
  @Override
  public Output<T> asOutput() {
    return fill.asOutput();
  }
}