aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/java/org/tensorflow/Shape.java
blob: d99b0078f62b2178f8fd5078729a0fa405d3d9e7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow;

import java.util.Arrays;

/**
 * The possibly partially known shape of a tensor produced by an operation.
 *
 * <p>A shape is backed by an array of dimension sizes. A {@code null} backing array represents a
 * shape whose number of dimensions is unknown, and a size of -1 marks a single dimension whose
 * size is unknown.
 */
public final class Shape {

  /**
   * Create a Shape representing an unknown number of dimensions.
   *
   * @return a shape for which {@link #numDimensions()} is -1
   */
  public static Shape unknown() {
    return new Shape(null);
  }

  /**
   * Create a Shape representing a scalar value.
   *
   * @return a shape with zero dimensions
   */
  public static Shape scalar() {
    return new Shape(new long[0]);
  }

  /**
   * Create a Shape representing an N-dimensional value.
   *
   * <p>Creates a Shape representing an N-dimensional value (N being at least 1), with the provided
   * size for each dimension. A -1 indicates that the size of the corresponding dimension is
   * unknown. For example:
   *
   * <pre>{@code
   * // A 2-element vector.
   * Shape vector = Shape.make(2);
   *
   * // A 2x3 matrix.
   * Shape matrix = Shape.make(2, 3);
   *
   * // A matrix with 4 columns but an unknown number of rows.
   * // This is typically used to indicate the shape of tensors that represent
   * // a variable-sized batch of values. The Shape below might represent a
   * // variable-sized batch of 4-element vectors.
   * Shape batch = Shape.make(-1, 4);
   * }</pre>
   *
   * @param firstDimensionSize size of the first dimension, or -1 if unknown
   * @param otherDimensionSizes sizes of the remaining dimensions, in order; -1 for unknown ones
   * @return a shape with {@code otherDimensionSizes.length + 1} dimensions
   */
  public static Shape make(long firstDimensionSize, long... otherDimensionSizes) {
    // The sizes are copied into a fresh array, so the caller's varargs array is never retained.
    long[] shape = new long[otherDimensionSizes.length + 1];
    shape[0] = firstDimensionSize;
    System.arraycopy(otherDimensionSizes, 0, shape, 1, otherDimensionSizes.length);
    return new Shape(shape);
  }

  /**
   * Number of dimensions represented by this shape.
   *
   * @return -1 if the number of dimensions is unknown, 0 if the shape represents a scalar, 1 for a
   *     vector, 2 for a matrix etc.
   */
  public int numDimensions() {
    return shape == null ? -1 : shape.length;
  }

  /**
   * The size of the i-th dimension.
   *
   * @param i zero-based index of the dimension
   * @return The size of the requested dimension or -1 if it is unknown. Also -1 when the number of
   *     dimensions of this shape is unknown.
   * @throws ArrayIndexOutOfBoundsException if the number of dimensions is known and {@code i} is
   *     not a valid dimension index
   */
  public long size(int i) {
    return shape == null ? -1 : shape[i];
  }

  /**
   * The total number of elements found in a tensor of this shape.
   *
   * <p>If the size of some dimensions is unknown, the total number of elements cannot be
   * calculated and -1 is returned.
   *
   * @return the number of elements or -1 if the size of some dimension is unknown
   */
  public int numElements() {
    if (shape == null) {
      return -1;
    }
    int total = 1;
    for (int i = 0; i < shape.length; ++i) {
      // TODO (karllessard): There might be a lossy conversion here from 'long' sizes to 'int'
      // total, but this issue seems ubiquitous in the current Java client implementation. It
      // should be addressed all at once.
      int size = (int) size(i);
      if (size < 0) {
        return -1;
      }
      total *= size;
    }
    return total;
  }

  /**
   * Returns the shape as an array.
   *
   * <p>Each element represents the size of the dimension at the given index. For example,
   * {@code shape.asArray()[2]} is equal to the size of the third dimension of this shape.
   *
   * <p>The returned array is a copy: modifying it has no effect on this shape.
   *
   * @return a new array holding the size of each dimension, or {@code null} if the number of
   *     dimensions of this shape is unknown
   */
  public long[] asArray() {
    // Defensive copy: handing out the backing array would let callers mutate this shape and
    // change the results of equals() and hashCode() after the fact.
    return shape == null ? null : shape.clone();
  }

  @Override
  public int hashCode() {
    return Arrays.hashCode(shape);
  }

  /**
   * Compares this shape with another object.
   *
   * <p>A shape is always equal to itself. Two distinct instances are equal only if they have the
   * same dimension sizes and none of them is unknown; distinct shapes with an unknown number of
   * dimensions, or with any dimension of size -1, are never considered equal.
   */
  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }

    if (obj instanceof Shape && Arrays.equals(this.shape, ((Shape) obj).shape)) {
      return !hasUnknownDimension();
    }

    // Identity was already ruled out above, so anything else is not equal.
    return false;
  }

  /** Succinct description of the shape meant for debugging. */
  @Override
  public String toString() {
    if (shape == null) {
      return "<unknown>";
    }
    // Format element by element so that only a size of exactly -1 is shown as "?". A textual
    // replace of "-1" would also mangle any other value whose decimal form contains "-1".
    StringBuilder description = new StringBuilder("[");
    for (int i = 0; i < shape.length; ++i) {
      if (i > 0) {
        description.append(", ");
      }
      if (shape[i] == -1) {
        description.append('?');
      } else {
        description.append(shape[i]);
      }
    }
    return description.append(']').toString();
  }

  // Package-private constructor. The array is stored as-is (not copied); null means that the
  // number of dimensions is unknown.
  Shape(long[] shape) {
    this.shape = shape;
  }

  // Size of each dimension (-1 for an unknown size), or null if the number of dimensions is
  // unknown.
  private final long[] shape;

  // Returns true if the number of dimensions is unknown or if any dimension has a size of -1.
  private boolean hasUnknownDimension() {
    if (shape == null) {
      return true;
    }

    for (long dimension : shape) {
      if (dimension == -1) {
        return true;
      }
    }

    return false;
  }
}