aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/xla_data.proto
blob: 73b3589dbf12341ddb3f3e819a550467a7b4d166 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This schema is written against the proto3 language version.
syntax = "proto3";

// Every message and enum declared in this file lives in the `xla` package.
package xla;
// Enables arena allocation support in the generated C++ code.
option cc_enable_arenas = true;

// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
//
// The numeric values are not in declaration order (e.g. BF16 = 16 is declared
// between F32 = 11 and F64 = 12). The numbers are what is written on the wire,
// so they must not be renumbered; use the "Next" marker at the bottom of the
// enum when adding a value.
enum PrimitiveType {
  // Invalid primitive type to serve as default.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  //
  // Note: if f16s are not natively supported on the device, they will be
  // converted to f16 from f32 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
  // and 7 bits for the mantissa.
  BF16 = 16;

  // Belongs to the fixed-width floating-point group above (F16, F32); it is
  // only declared here, after BF16, and keeps the value 12.
  F64 = 12;

  // Complex values of fixed width.
  C64 = 15;  // Paired F32 (real, imag), as in std::complex<float>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. They are used for things like returning multiple values from a
  // computation; e.g. a computation that returns weights and biases may have a
  // signature that results in a tuple like (f32[784x2000], f32[2000])
  //
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context-specific data to a custom
  // operation. Shapes of this primitive type will have empty dimensions and
  // tuple_shapes fields.
  OPAQUE = 14;

  // A token type threaded between side-effecting operations. Shapes of this
  // primitive type will have empty dimensions and tuple_shapes fields.
  TOKEN = 17;

  // Next = 18
}

// Describes the value held inside padding elements.
//
// Used by Layout.padding_value to describe the contents of the padding that
// Layout.padded_dimensions introduces.
enum PaddingValue {
  // Zero value of the enum, i.e. what an unset field reads as. It does not
  // name a usable kind of padding.
  INVALID_PAD = 0;

  // Zero padding must be 0-values that correspond to the shape's element type.
  ZERO_PAD = 1;

  // One padding must be 1-values that correspond to the shape's element type.
  ONE_PAD = 2;

  // "Lowest" padding must be the lowest values in the shape's element type,
  // used as padding for operations like max-accumulation.
  LOWEST_PAD = 3;

  // "Highest" padding must be the largest values in the shape's element type,
  // used as padding for operations like min-accumulation.
  HIGHEST_PAD = 4;

  // Unknown padding could be anything; e.g. floating NaNs!
  UNKNOWN_PAD = 5;
}

// Describes the padding configuration for Pad operation. The padding amount on
// both edges as well as between the elements are specified for each dimension.
message PaddingConfig {
  // Describes the padding configuration for a dimension.
  //
  // All three amounts are plain int64 fields, so an unset amount reads as 0.
  message PaddingConfigDimension {
    // Padding amount on the low-end (next to the index 0). May be negative.
    int64 edge_padding_low = 1;

    // Padding amount on the high-end (next to the highest index). May be
    // negative.
    int64 edge_padding_high = 2;

    // Padding amount between the elements. May not be negative.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions.
  // Presumably entry i configures dimension i of the padded operand - confirm
  // against the Pad operation semantics.
  repeated PaddingConfigDimension dimensions = 1;
}

// A format specifies the method used by a layout to store an array in memory.
//
// Referenced by Layout.format, where it determines which of the other Layout
// fields are used.
enum Format {
  // Zero value of the enum, i.e. what an unset Layout.format reads as. It
  // does not name a usable storage method.
  INVALID_FORMAT = 0;
  // The default layout, with exactly one storage location per element (ignoring
  // padding).
  DENSE = 1;
  // A sparsely encoded layout, providing only the index/value pairs of non-zero
  // elements.
  SPARSE = 2;
}

// A layout describes how the array is placed in (1D) memory space.  This
// includes the minor-to-major ordering of dimensions within a shape, as well as
// any padding present in those dimensions.
//
// Clients must specify the layouts of input Literals to the
// computation. Layouts specified in interior operations which take Shapes (for
// example, Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message Layout {
  // The method used to store the data in memory. The format determines which of
  // the other fields are used by the layout.
  Format format = 4;

  // Sequence of dimension numbers, from minor (fastest varying index) to major
  // (slowest varying index). This field is required. (proto3 has no way to
  // enforce that; it is a requirement on whoever produces the message.)
  repeated int64 minor_to_major = 1;

  // The width to which each dimension of the layout is padded. If present, the
  // size of padded_dimensions must equal the rank of the shape. The padding
  // appears at the end of a dimension, not at the beginning. This kind of
  // padding, unlike padding in e.g. convolution, is not part of the shape.
  // This field must be unset unless the format is DENSE.
  repeated int64 padded_dimensions = 2;

  // Describes the values in the padding specified by padded_dimensions. This
  // field must be unset unless the format is DENSE.
  PaddingValue padding_value = 3;

  // The maximum number of elements that can be stored for SPARSE formats.  This
  // can be used to determine the maximum size in bytes of arrays stored in
  // memory.  This field must be unset unless the format is SPARSE.
  int64 max_sparse_elements = 5;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
  // LayoutUtil::Hash appropriately to account for the new field.
}
// LINT.ThenChange( \
//     https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,      \
//     https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc)

// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message Shape {
  // Field number 1 and the field name "rank" are reserved so that neither can
  // be assigned to a new field.
  reserved 1;
  reserved "rank";

  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension.
  // In XLA, dimensions are numbered from 0 to N-1 for an
  // N-dimensional array. The first element of 'dimensions' is the size of
  // dimension 0, the second element is the size of dimension 1, and so forth.
  // Empty list indicates a scalar.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of constituent shapes in the tuple sequence.
  // The element type is Shape itself, so tuple shapes can nest.
  repeated Shape tuple_shapes = 4;

  // The layout used to back this shape.
  Layout layout = 5;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal(),
  // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for
  // the new field.
}
// LINT.ThenChange( \
//     https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc)

// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShape {
  // One Shape per parameter of the computation.
  repeated Shape parameters = 1;
  // The Shape of the computation's output.
  Shape result = 2;
  // Names of the parameters. Presumably parallel to `parameters` (entry i
  // names parameter i) - confirm against the code that fills this in.
  repeated string parameter_names = 3;
}

// Statistics of a computation.
//
// Both counts are stored as double rather than as an integer type.
message ComputationStats {
  // The number of floating point operations in the computation.
  double flop_count = 1;

  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}

// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  //
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have the op_type be "SoftMax".)
  string op_type = 1;
  // The user-specified name of the op.
  //
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;
  // Indicate a file and line that this op is associated to in a user's program.
  //
  // e.g. it could be the file and line of user code that generated the op.
  string source_file = 3;
  // The line within source_file; see the comment on source_file.
  int32 source_line = 4;
}

// Profile data from the execution of a computation.
//
// Each time field carries its unit in its name (_ms for milliseconds, _ns for
// nanoseconds).
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only
  // set if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This is
  // a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. Current implementation does not spend any cycles
  // for the input data transfer since the memory is initialized with the proper
  // values before the execution.
  int64 compute_and_transfer_time_ns = 5;

  // The size of the binary code in the executable.
  int64 executable_size_in_bytes = 6;
}

// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  // The numeric value of the handle. How values are allocated is not defined
  // by this message.
  int64 handle = 1;
}

// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  // The numeric value of the handle. How values are allocated is not defined
  // by this message.
  int64 handle = 1;
}

// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is the
// number of replicas.
message DeviceHandle {
  // The numeric value of the handle. How values are allocated is not defined
  // by this message.
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
  // The numeric value of the handle. How values are allocated is not defined
  // by this message.
  int64 handle = 1;
  // The direction of transfer a channel is used for.
  enum ChannelType {
    // Invalid channel type to serve as default.
    CHANNEL_TYPE_INVALID = 0;

    // A channel for sending data between devices.
    DEVICE_TO_DEVICE = 1;

    // A channel for sending data from the device to the host. Can only be used
    // with a Send operation.
    DEVICE_TO_HOST = 2;

    // A channel for sending data from the host to the device. Can only be used
    // with a Recv operation.
    HOST_TO_DEVICE = 3;
  }
  // Which kind of channel this handle refers to.
  ChannelType type = 2;
}

// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which
// represents the device ids assigned to a set of replicated computations.
// See xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  // Number of replicas; each logical computation runs on this many physical
  // devices (see ComputationDevice below).
  int32 replica_count = 1;
  // Number of logical computations. Presumably equal to the number of entries
  // in computation_devices - confirm against xla::DeviceAssignment.
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    // Device ids for the replicas of one computation. Presumably holds
    // replica_count entries, indexed by replica - confirm.
    repeated int32 replica_device_ids = 1;
  }
  // One ComputationDevice per logical computation.
  repeated ComputationDevice computation_devices = 3;
}

// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
//
// The value fields are named after PrimitiveType values (preds for PRED, s32s
// for S32, and so on). Presumably only the field matching shape.element_type
// is populated for a given literal - confirm against the Literal code.
// Field numbers are not in declaration order; keep the "Next" marker at the
// bottom in sync when adding a field.
message LiteralProto {
  // The shape of this literal; it implies how the value fields below are
  // structured.
  Shape shape = 1;
  repeated bool preds = 2;
  // The 8-bit integer values are carried in `bytes` fields rather than in
  // repeated integer fields.
  bytes s8s = 15;
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;  // Stored as interleaved real, imag floats.
  // Nested literals; the element type is LiteralProto itself. Presumably the
  // elements of a tuple-shaped literal - confirm.
  repeated LiteralProto tuple_literals = 10;
  // The F16s and BF16s are encoded in little endian byte order
  bytes f16s = 11;
  bytes bf16s = 13;
  // NOTE(review): presumably the element indices for literals whose layout
  // format is SPARSE; the encoding is not described in this file - confirm.
  repeated int64 sparse_indices = 14;
  // Next = 16
}

// Describes the window along a single dimension. A Window (below) is made up
// of a list of these.
message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, means the amount of padding to add to the base area at the low
  // end of this dimension; if negative, its magnitude is the number of
  // elements removed from the low end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of padding
  // values to pad on the left, given that indices increase when going right.
  // The actual padding value depends upon the context. Convolution pads with
  // zeros. ReduceWindow and SelectAndScatter pads with the reduce function's
  // init value.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of values to
  // pad on the right, given that indices increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation factor
  // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each kernel element. This value may not be less
  // than 1. See documentation for convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of 1
  // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
  // placed between each base area element. This value may not be less than 1.
  // See documentation for convolution.
  int64 base_dilation = 6;

  // Window reversal means that this dimension was logically reversed before the
  // operation.
  bool window_reversal = 7;
}

// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
  // Per-dimension description of the window (size, stride, padding,
  // dilation, reversal); see WindowDimension.
  repeated WindowDimension dimensions = 1;
}

// Describes the dimension numbers for a gather operation.
//
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for
// more details.
message GatherDimensionNumbers {
  // "Window indices" is a term for a set of indices that index into the
  // interior of a dynamic-slice from the input tensor, the starting indices for
  // which were computed from output_gather_dims (see the operation semantic for
  // how this is defined) and the start_indices tensor.
  //
  // NOTE(review): output_gather_dims is not a field of this message; it is
  // presumably a term from the operation semantics (or an older name) -
  // confirm against the linked documentation.
  //
  // The window indices for a specific output index Out is computed as:
  //
  //  i = 0
  //  for (k : [0, input_tensor_shape.rank))
  //    window_indices[k] =
  //      if k in collapsed_slice_dims
  //      then 0
  //      else Out[offset_dims[i++]]
  repeated int64 offset_dims = 1;
  // Input dimensions whose window index is fixed to 0 in the computation
  // described above.
  repeated int64 collapsed_slice_dims = 2;

  // This is interpreted as a map from i to start_index_map[i]. It
  // transforms the gather index looked up from the start_indices tensor into
  // the starting index in the input space.
  repeated int64 start_index_map = 3;

  // The dimension in the start_indices input that contains the starting
  // indices.
  int64 index_vector_dim = 4;
}

// Describes the dimension numbers for a scatter operation.
//
// All the fields are similar to the corresponding fields in
// GatherDimensionNumbers. Differences are noted below.
message ScatterDimensionNumbers {
  // The set of dimensions in the updates shape that are window dimensions.
  repeated int64 update_window_dims = 1;
  // The set of window dimensions that must be inserted into the updates shape.
  repeated int64 inserted_window_dims = 2;

  // Presumably the counterpart of GatherDimensionNumbers.start_index_map
  // (same field number and position) - confirm against the scatter
  // operation semantics.
  repeated int64 scatter_dims_to_operand_dims = 3;
  // Counterpart of GatherDimensionNumbers.index_vector_dim.
  int64 index_vector_dim = 4;
}

// Describes which dimension numbers of a convolution's input, kernel (rhs)
// and output hold the batch, feature and spatial dimensions.
//
// Field numbers are not in declaration order; use the "Next" marker at the
// bottom when adding a field.
message ConvolutionDimensionNumbers {
  // Field numbers 1, 2 and 5 are not used by any field below even though the
  // "Next" marker has moved past them, so they presumably belonged to fields
  // that have since been removed. They are reserved so that they cannot be
  // handed out again with a different meaning.
  reserved 1, 2, 5;

  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the input.
  repeated int64 input_spatial_dimensions = 11;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in
  // the convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the kernel (rhs). window.strides(0) is the
  // stride in the kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the output.
  repeated int64 output_spatial_dimensions = 12;

  // Next = 13
}

// Selects the kind of FFT to perform. Note that the zero value, which is what
// an unset field reads as, is the forward complex FFT.
enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              //                   fft_length real out
}

// Describes, for the two operands ('lhs' and 'rhs') of a dot operation, which
// dimension numbers are contracting dimensions and which are batch
// dimensions.
message DotDimensionNumbers {
  // The dimension numbers that represent the 'lhs' contracting dimensions.
  repeated int64 lhs_contracting_dimensions = 1;
  // The dimension numbers that represent the 'rhs' contracting dimensions.
  repeated int64 rhs_contracting_dimensions = 2;
  // The dimension numbers that represent the 'lhs' batch dimensions.
  repeated int64 lhs_batch_dimensions = 3;
  // The dimension numbers that represent the 'rhs' batch dimensions.
  repeated int64 rhs_batch_dimensions = 4;
}

// Identifies the distribution that random numbers are drawn from. The value
// comments refer to entries of a `parameter` list; that list is supplied by
// whatever uses this enum, not by the enum itself.
enum RandomDistribution {
  // The "Next" marker below is 4 while the highest value in use is 2, so
  // value 3 presumably belonged to an entry that has since been removed. It
  // is reserved so that it cannot be handed out again with a different
  // meaning.
  reserved 3;

  // Zero value of the enum, i.e. what an unset field reads as. It does not
  // name a usable distribution.
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;

  // Next: 4
}

// Describes a sharding, i.e. how the instruction it is attached to is
// distributed across devices. Which fields are meaningful depends on `type`.
message OpSharding {
  // The kind of sharding. Note that the zero value, which is what an unset
  // `type` reads as, is REPLICATED.
  enum Type {
    // This sharding is replicated across all devices (implies maximal,
    // all other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
  }
  // Which kind of sharding this message describes.
  Type type = 1;
  // The shape of the sharded tile.
  Shape tile_shape = 2;
  // The shape of the tile assignment tensor - this must be the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  // NOTE(review): no field named tile_assignment_shape exists in this
  // message; this presumably means the shape given by
  // tile_assignment_dimensions - confirm.
  repeated int64 tile_assignment_devices = 4;
  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
  // in pre-order. The tuple shape could be nested; here we store just a
  // flattened list of all leaves in the tuple shape. Note that the tuple shape
  // is not stored here; shardings do not store the shapes to which they are
  // applied, this is inferred from the instruction this sharding gets attached
  // to.
  repeated OpSharding tuple_shardings = 5;
}

// Describes the replica groups in a cross replica op (e.g., all-reduce and
// all-to-all).
message ReplicaGroup {
  // The ids of the replicas that belong to the same group. The ordering of the
  // ids matters in some ops (e.g., all-to-all).
  repeated int64 replica_ids = 1;
}

// Describes the source target pair in the collective permute op.
//
// NOTE(review): the message does not say what the two ids identify;
// presumably replica ids, as in ReplicaGroup - confirm against the collective
// permute op semantics.
message SourceTarget {
  // Id of the sending side of the pair.
  int64 source = 1;
  // Id of the receiving side of the pair.
  int64 target = 2;
}

// Used to indicate the precision configuration. It has backend specific
// meaning.
message PrecisionConfig {
  // The available precision levels. What each level means is backend
  // specific (see the message comment). DEFAULT is the zero value, so it is
  // what an unset entry reads as.
  enum Precision {
    DEFAULT = 0;
    HIGH = 1;
    HIGHEST = 2;

    // Next: 3
  }
  // A list of precision levels. Presumably one entry per operand of the
  // instruction this config is attached to, in operand order - confirm.
  repeated Precision operand_precision = 1;

  // Next: 2
}