// Extracted from tensorflow/compiler/xla/xla_data.proto
// (blob eac8f2ff07e4a885affdc0f7b1563d3a2cb606d7).
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;
option cc_enable_arenas = true;

// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
//
// Note: the numeric values are not in declaration order (BF16 = 16 and
// C64 = 15 are declared before lower-numbered entries). Always rely on the
// explicit numbers, never on the position of a value in this list.
enum PrimitiveType {
  // Invalid primitive type to serve as default.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  //
  // Note: if f16s are not natively supported on the device, they will be
  // converted to f16 from f32 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
  // and 7 bits for the mantissa.
  BF16 = 16;

  F64 = 12;

  // Complex values of fixed width.
  C64 = 15;  // Paired F32 (real, imag), as in std::complex<float>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. They are used for things like returning multiple values from a
  // computation; e.g. a computation that returns weights and biases may have a
  // signature that results in a tuple like (f32[784x2000], f32[2000])
  //
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context specific data to a custom
  // operation.
  OPAQUE = 14;

  // Next = 17
}

// Describes the value held inside padding elements. Used by
// Layout.padding_value to describe the contents of layout padding.
enum PaddingValue {
  // Zero enum value, and therefore the proto3 default when the field is unset.
  INVALID_PAD = 0;

  // Zero padding must be 0-values that correspond to the shape's element type.
  ZERO_PAD = 1;

  // One padding must be 1-values that correspond to the shape's element type.
  ONE_PAD = 2;

  // "Lowest" padding must be the lowest values in the shape's element type,
  // used as padding for operations like max-accumulation.
  LOWEST_PAD = 3;

  // "Highest" padding must be the largest values in the shape's element type,
  // used as padding for operations like min-accumulation.
  HIGHEST_PAD = 4;

  // Unknown padding could be anything; e.g. floating NaNs!
  UNKNOWN_PAD = 5;
}

// Describes the padding configuration for the Pad operation (see
// PadRequest.padding_config). For each dimension, the padding amounts on both
// edges as well as between the elements are specified.
message PaddingConfig {
  // Describes the padding configuration for a dimension.
  message PaddingConfigDimension {
    // Padding amount on the low-end (next to the index 0).
    int64 edge_padding_low = 1;

    // Padding amount on the high-end (next to the highest index).
    int64 edge_padding_high = 2;

    // Padding amount between the elements.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions.
  // NOTE(review): presumably one entry per operand dimension, in dimension
  // order -- confirm.
  repeated PaddingConfigDimension dimensions = 1;
}

// A layout describes how the array is placed in (1D) memory space.  This
// includes the minor-to-major ordering of dimensions within a shape, as well as
// any padding present in those dimensions.
//
// Clients must specify the layouts of input Literals to the
// computation. Layouts specified in interior operations which take Shapes (for
// example, Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
message Layout {
  // Sequence of dimension numbers, from minor (fastest varying index) to major
  // (slowest varying index). This field is required.
  // (proto3 has no `required` label, so this is a convention that the schema
  // itself does not enforce.)
  repeated int64 minor_to_major = 1;

  // The width to which the layout of each dimension is padded up
  // to. If present, the size of the padded_dimensions must equal the
  // rank of the shape. The padding appears at the end of a dimension,
  // not at the beginning. This kind of padding, unlike padding in
  // e.g. convolution, is not part of the shape.
  repeated int64 padded_dimensions = 2;

  // Describes the values in the padding specified by
  // padded_dimensions.
  PaddingValue padding_value = 3;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal()
  // appropriately to account for the new field.
}

// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
message Shape {
  // Field number 1 and the field name "rank" are reserved so that neither can
  // be reused by a new field.
  reserved 1;
  reserved "rank";

  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension.
  // In XLA, dimensions are numbered from 0 to N-1 for an
  // N-dimensional array. The first element of 'dimensions' is the size of
  // dimension 0, the second element is the size of dimension 1, and so forth.
  // Empty list indicates a scalar.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of constituent shapes in the tuple sequence.
  repeated Shape tuple_shapes = 4;

  // The layout used to back this shape.
  Layout layout = 5;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
  // ShapeUtil::Compatible() appropriately to account for the new field.
}

// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShape {
  // Shapes of the computation's parameters.
  repeated Shape parameters = 1;

  // Shape of the computation's output.
  Shape result = 2;

  // Names of the parameters.
  // NOTE(review): presumably index-aligned with `parameters` -- confirm.
  repeated string parameter_names = 3;
}

// Statistics of a computation.
//
// Both counts are stored as doubles rather than integers.
message ComputationStats {
  // The number of floating point operations in the computation.
  double flop_count = 1;

  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}

// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  //
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have the op_type be "SoftMax".)
  string op_type = 1;

  // The user-specified name of the op.
  //
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;

  // Indicates the file and line that this op is associated with in a user's
  // program.
  //
  // e.g. it could be the file and line of user code that generated the op.
  string source_file = 3;

  // The line within source_file; see the comment on source_file.
  int32 source_line = 4;
}

// Profile data from the execution of a computation.
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only
  // set if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This is
  // a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. Current implementation does not spend any cycles
  // for the input data transfer since the memory is initialized with the proper
  // values before the execution.
  int64 compute_and_transfer_time_ns = 5;
}

// Handle given to a user that represents a computation that the user builds up
// before execution.
message ComputationHandle {
  // Numeric value identifying the computation.
  int64 handle = 1;
}

// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  // Numeric value identifying the execution.
  int64 handle = 1;
}

// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  // Numeric value identifying the allocation.
  int64 handle = 1;
}

// Handle given to a user that represents a data result in a computation.
// This is used to pass to subsequent computations that depend upon the data as
// an operand.
message ComputationDataHandle {
  // Numeric value identifying the data result.
  int64 handle = 1;
}

// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is the
// number of replicas.
message DeviceHandle {
  // Numeric value identifying the replicated virtual device.
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so
// Send instructions will be blocked until the data is transferred.
message ChannelHandle {
  // Numeric value identifying the channel.
  int64 handle = 1;
}

// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which
// represents the device ids assigned to a set of replicated computations.
// See xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  // Number of replicas; each logical computation runs on this many physical
  // devices (see ComputationDevice below).
  int32 replica_count = 1;

  // Number of logical computations.
  // NOTE(review): presumably equals computation_devices.size() -- confirm.
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    // Device ids for the replicas of one computation.
    // NOTE(review): presumably replica_count entries -- confirm.
    repeated int32 replica_device_ids = 1;
  }

  // Device assignments for the logical computations.
  repeated ComputationDevice computation_devices = 3;
}

// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
//
// Each payload field below is named after a PrimitiveType. There is no
// dedicated field for S8, S16 or U16 in this message.
// NOTE(review): presumably only the field matching shape.element_type is
// populated -- confirm.
message LiteralProto {
  // Shape of the literal; determines how the payload fields are interpreted.
  Shape shape = 1;
  repeated bool preds = 2;
  // Note: stored as a single bytes field rather than a repeated field.
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;  // Stored as interleaved real, imag floats.
  // Element literals when this literal is a tuple.
  repeated LiteralProto tuple_literals = 10;
  // The F16s and BF16s are encoded in little endian byte order
  bytes f16s = 11;
  bytes bf16s = 13;
  // Next = 14
}

// Describes a window along a single dimension: its size, stride, padding and
// dilation. Window.dimensions is a repeated list of these.
message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, the amount of zero padding to add to the base area at the low
  // end of this dimension; if negative, its absolute value is the number of
  // elements removed from the low end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of zeroes to
  // pad on the left, given that indices increase when going right.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For
  // example, in the horizontal dimension of a rectangle, this would
  // be the number of zeroes to pad on the right, given that indices
  // increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation factor
  // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each kernel element. See documentation for
  // convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of 1
  // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
  // placed between each base area element. See documentation for convolution.
  int64 base_dilation = 6;
}

// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
  // Per-dimension window description (size, stride, padding, dilation).
  repeated WindowDimension dimensions = 1;
}

// Operation requests. The request messages below are collected into a tagged
// union by the `op` oneof field in OpRequest.

// Request carried in OpRequest's `op` oneof as `constant_request`.
// Note: field number 1 is not used by this message.
message ConstantRequest {
  // The literal carried by this request.
  // NOTE(review): presumably the value of the constant -- confirm.
  LiteralProto literal = 2;
}

// Request carried in OpRequest's `op` oneof as `get_tuple_element_request`.
// Note: field number 1 is not used by this message.
message GetTupleElementRequest {
  // Handle of the operand.
  // NOTE(review): presumably must be tuple-shaped -- confirm.
  ComputationDataHandle operand = 2;

  // Index of the element to get.
  // NOTE(review): presumably zero-based -- confirm.
  int64 index = 3;
}

// Request carried in OpRequest's `op` oneof as `slice_request`.
// Note: field number 1 is not used by this message.
//
// NOTE(review): presumably start_indices, limit_indices and strides each hold
// one entry per operand dimension, and limit indices are exclusive -- confirm.
message SliceRequest {
  // Handle of the operand to slice.
  ComputationDataHandle operand = 2;
  repeated int64 start_indices = 3;
  repeated int64 limit_indices = 4;
  repeated int64 strides = 5;
}

// Request carried in OpRequest's `op` oneof as `dynamic_slice_request`.
// Unlike SliceRequest, `start_indices` here is a data handle rather than a
// list of integers.
// Note: field number 1 is not used by this message.
message DynamicSliceRequest {
  // Operand from which to slice at dynamic 'start_indices'.
  ComputationDataHandle operand = 2;
  // Dynamically computed 'start_indices' for slice operation.
  ComputationDataHandle start_indices = 3;
  // Slice sizes for each dimension (note that indices calculations are computed
  // modulo dimension sizes to avoid out-of-bound array accesses).
  repeated int64 slice_sizes = 4;
}

// Request carried in OpRequest's `op` oneof as
// `dynamic_update_slice_request`.
// Note: field number 1 is not used by this message.
message DynamicUpdateSliceRequest {
  // Operand on which slice 'update' is to be applied.
  ComputationDataHandle operand = 2;
  // The slice update to apply to 'operand'.
  ComputationDataHandle update = 3;
  // Dynamically computed start indices for the update slice operation.
  ComputationDataHandle start_indices = 4;
}

// Identifies which dimension numbers of a convolution's input (lhs), kernel
// (rhs) and output play the batch, feature and spatial roles. Used by
// ConvolveRequest.dimension_numbers.
//
// Note: fields are not declared in field-number order, and field numbers 1
// and 2 are not used by this message.
// NOTE(review): 1 and 2 are presumably retired fields; consider adding
// `reserved 1, 2;` once the history is confirmed.
message ConvolutionDimensionNumbers {
  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the input (lhs) and output.
  repeated int64 spatial_dimensions = 5;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in
  // the convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the kernel (rhs). window.strides(0) is the
  // stride in the kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;
}

// Request carried in OpRequest's `op` oneof as `convolve_request`.
// Note: field number 1 is not used by this message.
message ConvolveRequest {
  // NOTE(review): presumably the convolution input -- confirm.
  ComputationDataHandle lhs = 2;
  ComputationDataHandle rhs = 3;  // This is the filter/kernel.
  Window window = 4;              // Describes the filter/kernel.
  // Assigns batch/feature/spatial roles to the dimensions of lhs, rhs and the
  // output; see ConvolutionDimensionNumbers.
  ConvolutionDimensionNumbers dimension_numbers = 5;
}

// Selects the kind of FFT requested by FftRequest.fft_type.
//
// Note: unlike the other enums in this file, the zero value (FFT) is a real
// operation rather than an "invalid" placeholder, so an unset fft_type reads
// as FFT.
enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              //                   fft_length real out
}

// Request carried in OpRequest's `op` oneof as `fft_request`.
message FftRequest {
  // Which transform to perform; see FftType.
  FftType fft_type = 1;
  repeated int64 fft_length = 2;  // Multivalent for higher-order FFT.
  // Handle of the data to transform.
  ComputationDataHandle operand = 3;
}

// Request carried in OpRequest's `op` oneof as `infeed_request`.
// Note: field number 1 is not used by this message.
message InfeedRequest {
  // The shape of the data returned by reading the device's infeed buffer.
  Shape shape = 2;

  // Additional infeed configuration for the backend.
  bytes config = 3;
}

// Request carried in OpRequest's `op` oneof as `outfeed_request`.
message OutfeedRequest {
  // The shape of the data returned by reading the device's outfeed buffer.
  Shape shape = 1;

  // Operand to the Outfeed. Supports tuple.
  ComputationDataHandle operand = 2;

  // Backend-specific information for how to perform the outfeed.
  bytes outfeed_config = 3;
}

// Request carried in OpRequest's `op` oneof as `call_request`.
// Note: field number 1 is not used by this message.
message CallRequest {
  // Handle of the computation to call.
  ComputationHandle to_apply = 2;

  // Handles of the operands.
  // NOTE(review): presumably passed as the arguments of `to_apply`, in
  // order -- confirm.
  repeated ComputationDataHandle operands = 3;
}

// Request carried in OpRequest's `op` oneof as `custom_call_request`.
// Note: field number 1 is not used by this message.
message CustomCallRequest {
  // Name of the call target.
  // NOTE(review): how the name is resolved is not specified here -- confirm.
  string call_target_name = 2;

  // Handles of the operands.
  repeated ComputationDataHandle operands = 3;

  // NOTE(review): presumably the shape of the custom call's result -- confirm.
  Shape shape = 4;
}

// Request carried in OpRequest's `op` oneof as `map_request`.
// Note: field number 1 is not used by this message.
message MapRequest {
  // Handles of the operands being mapped over.
  repeated ComputationDataHandle operands = 2;

  // Handle of the computation to apply.
  ComputationHandle to_apply = 3;

  // NOTE(review): presumably extra operands passed to `to_apply` unchanged
  // rather than mapped over -- confirm.
  repeated ComputationDataHandle static_operands = 4;

  // The dimensions over which to map.
  // Example mapping a Dot operation along the batch dimension 0:
  //   operand0.shape = [2, 2, 2], operand1.shape = [2,2,3]
  //   Map({operand0, operand1}, Dot, {0})
  repeated int64 dimensions = 5;
}

// Request carried in OpRequest's `op` oneof as `reduce_request`.
// Note: field number 1 is not used by this message.
message ReduceRequest {
  // Operand to the reduction.
  ComputationDataHandle operand = 2;

  // Initial value for the reduction. This must be consistent with the result
  // shape of to_apply.
  ComputationDataHandle init_value = 3;

  // The dimensions to reduce over.
  repeated int64 dimensions = 4;

  // The computation to apply in the reduction.
  ComputationHandle to_apply = 5;
}

// Request carried in OpRequest's `op` oneof as `reduce_window_request`.
// Has the same fields as ReduceRequest, except that a Window replaces the
// list of dimensions.
// Note: field number 1 is not used by this message.
message ReduceWindowRequest {
  // Handle of the operand.
  ComputationDataHandle operand = 2;

  // Handle of the initial value.
  // NOTE(review): presumably the starting value of each window's reduction,
  // as in ReduceRequest -- confirm.
  ComputationDataHandle init_value = 3;

  // The window configuration.
  Window window = 4;

  // Handle of the computation to apply.
  ComputationHandle to_apply = 5;
}

// Request carried in OpRequest's `op` oneof as `batch_norm_training_request`.
//
// NOTE(review): field meanings are not documented in this file. Presumably
// they follow the usual batch-normalization definitions (scale/offset as the
// learned parameters, epsilon as the small constant added to the variance,
// feature_index as the feature dimension of `operand`) -- confirm.
message BatchNormTrainingRequest {
  ComputationDataHandle operand = 1;
  ComputationDataHandle scale = 2;
  ComputationDataHandle offset = 3;
  float epsilon = 4;
  int64 feature_index = 5;
}

// Request carried in OpRequest's `op` oneof as
// `batch_norm_inference_request`.
//
// Compared with BatchNormTrainingRequest, this message additionally carries
// `mean` and `variance` handles.
// NOTE(review): field meanings are not documented in this file. Presumably
// they follow the usual batch-normalization definitions (epsilon as the small
// constant added to the variance, feature_index as the feature dimension of
// `operand`) -- confirm.
message BatchNormInferenceRequest {
  ComputationDataHandle operand = 1;
  ComputationDataHandle scale = 2;
  ComputationDataHandle offset = 3;
  ComputationDataHandle mean = 4;
  ComputationDataHandle variance = 5;
  float epsilon = 6;
  int64 feature_index = 7;
}

// Request carried in OpRequest's `op` oneof as `batch_norm_grad_request`.
//
// Unlike the training and inference requests, this message has no `offset`
// field and carries a `grad_output` handle.
// NOTE(review): field meanings are not documented in this file. Presumably
// `grad_output` is the gradient with respect to the batch-norm output, and
// the remaining fields follow the usual batch-normalization definitions --
// confirm.
message BatchNormGradRequest {
  ComputationDataHandle operand = 1;
  ComputationDataHandle scale = 2;
  ComputationDataHandle mean = 3;
  ComputationDataHandle variance = 4;
  ComputationDataHandle grad_output = 5;
  float epsilon = 6;
  int64 feature_index = 7;
}

// Request carried in OpRequest's `op` oneof as `cross_replica_sum_request`.
// Note: field number 1 is not used by this message.
message CrossReplicaSumRequest {
  // Handle of the operand.
  // NOTE(review): presumably the value summed across replicas -- confirm.
  ComputationDataHandle operand = 2;
}

// Request carried in OpRequest's `op` oneof as `select_and_scatter_request`.
// Note: field number 1 is not used by this message.
message SelectAndScatterRequest {
  // Operand array on which the windows slide.
  ComputationDataHandle operand = 2;

  // Source array for the data to scatter.
  ComputationDataHandle source = 3;

  // Initial scalar value for each element in the output.
  ComputationDataHandle init_value = 4;

  // Window configuration.
  Window window = 5;

  // Binary function used to select an element from each window.
  ComputationHandle select = 6;

  // Binary function used to combine each scattered value from source with the
  // current output value at the selected location.
  ComputationHandle scatter = 7;
}

// Request carried in OpRequest's `op` oneof as `reverse_request`.
// Note: field number 1 is not used by this message.
message ReverseRequest {
  // Handle of the operand.
  ComputationDataHandle operand = 2;

  // NOTE(review): presumably the dimensions along which element order is
  // reversed -- confirm.
  repeated int64 dimensions = 3;
}

// Request carried in OpRequest's `op` oneof as `broadcast_request`.
// Note: field number 1 is not used by this message.
message BroadcastRequest {
  // Handle of the operand.
  ComputationDataHandle operand = 2;

  // NOTE(review): presumably the sizes of the new dimensions added by the
  // broadcast -- confirm.
  repeated int64 broadcast_sizes = 3;
}

// Request carried in OpRequest's `op` oneof as `pad_request`.
// Note: field number 1 is not used by this message.
message PadRequest {
  // Handle of the operand to pad.
  ComputationDataHandle operand = 2;

  // Handle of the padding value. Note that this is a data handle, not the
  // PaddingValue enum used by Layout.
  // NOTE(review): presumably a scalar of the operand's element type --
  // confirm.
  ComputationDataHandle padding_value = 3;

  // Low, high and interior padding amounts for each dimension; see
  // PaddingConfig.
  PaddingConfig padding_config = 4;
}

// Request carried in OpRequest's `op` oneof as `reshape_request`.
// Note: field number 1 is not used by this message.
message ReshapeRequest {
  // Handle of the operand to reshape.
  ComputationDataHandle operand = 2;

  // The dimension order for collapse (from fastest-changing to slowest).
  repeated int64 dimensions = 3;

  // The new dimension sizes (from dimension 0 to n-1).
  repeated int64 new_sizes = 4;
}

// Request carried in OpRequest's `op` oneof as `transpose_request`.
// Note: field number 1 is not used by this message.
message TransposeRequest {
  // Handle of the operand to transpose.
  ComputationDataHandle operand = 2;

  // The permutation of the operand's dimensions (in the range 0 to n-1).
  repeated int64 dimensions = 3;
}

// Request carried in OpRequest's `op` oneof as `parameter_request`.
// Note: field number 1 is not used by this message.
message ParameterRequest {
  // Shape of the parameter.
  Shape shape = 2;

  // NOTE(review): presumably the parameter's number (position) within the
  // computation -- confirm.
  int64 parameter = 3;

  // Name of the parameter.
  string name = 4;
}

// Stand-alone request that is not a member of OpRequest's `op` oneof.
// NOTE(review): presumably asks for the shape of `operand` within
// `computation`, answered by GetLocalShapeResponse -- confirm.
message GetLocalShapeRequest {
  // Handle of the computation.
  ComputationHandle computation = 1;

  // Handle of the operand.
  ComputationDataHandle operand = 2;
}

// NOTE(review): presumably the response paired with GetLocalShapeRequest --
// confirm.
message GetLocalShapeResponse {
  // The shape carried by this response.
  Shape shape = 1;
}

// Request carried in OpRequest's `op` oneof as `trace_request`.
// Note: field number 1 is not used by this message.
message TraceRequest {
  // NOTE(review): presumably a label attached to the traced value -- confirm.
  string tag = 2;

  // Handle of the operand.
  ComputationDataHandle operand = 3;
}

// Request carried in OpRequest's `op` oneof as `convert_request`.
// Note: field number 1 is not used by this message.
message ConvertRequest {
  // Handle of the operand to convert.
  ComputationDataHandle operand = 2;

  // The element type to convert to.
  PrimitiveType new_element_type = 3;
}

// Request carried in OpRequest's `op` oneof as `concatenate_request`.
// Note: field number 1 is not used by this message.
message ConcatenateRequest {
  // Handles of the operands to concatenate.
  repeated ComputationDataHandle operands = 2;
  // The dimension in which we concatenate; e.g. if you had dimension arrays of
  // [4, 1] and [5, 1], you'd concatenate in dimension 0 to produce a [9, 1].
  // Attempting to concatenate those in dimension 1 would produce an error, as
  // 4 != 5 (and there is no ragged array support).
  int64 dimension = 3;
}

// Request carried in OpRequest's `op` oneof as `while_request`.
// Note: field number 1 is not used by this message.
//
// NOTE(review): presumably `body` is applied repeatedly, starting from `init`,
// for as long as `condition` holds -- confirm.
message WhileRequest {
  // Handle of the condition computation.
  ComputationHandle condition = 2;

  // Handle of the body computation.
  ComputationHandle body = 3;

  // Handle of the initial value.
  ComputationDataHandle init = 4;
}

// Selects the operation performed by a UnaryOpRequest (see
// UnaryOpRequest.unop).
enum UnaryOperation {
  // Zero enum value, and therefore the proto3 default when the field is unset.
  UNOP_INVALID = 0;

  // Elementwise, logical negation on booleans and bitwise negation on ints.
  UNOP_NOT = 1;

  // Elementwise, computes e^x.
  UNOP_EXP = 2;

  // Elementwise, computes -x.
  UNOP_NEGATE = 3;

  // Puts the elements in the operand into sorted order.
  UNOP_SORT = 4;

  // Elementwise, computes tanh(x).
  UNOP_TANH = 5;

  // Elementwise, computes the natural logarithm of x.
  UNOP_LOG = 6;

  // Elementwise, computes the floor of x.
  UNOP_FLOOR = 7;

  // Elementwise, computes the ceil of x.
  UNOP_CEIL = 8;

  // Elementwise, computes the abs of x.
  UNOP_ABS = 9;

  // Elementwise, computes the sign of x.
  UNOP_SIGN = 10;

  // Elementwise, tests if values are finite (not NaN or inf)
  UNOP_IS_FINITE = 11;

  // Elementwise, computes the cosine of x.
  UNOP_COS = 12;

  // Elementwise, computes the sine of x.
  UNOP_SIN = 13;

  // Elementwise, rounds x to nearest integral value, rounding half-way cases
  // away from zero.
  UNOP_ROUND_NEAREST_AFZ = 14;

  // Elementwise, extract real component of complex x.
  UNOP_REAL = 15;

  // Elementwise, extract imaginary component of complex x.
  UNOP_IMAG = 16;
}

// Request carried in OpRequest's `op` oneof as `unary_op_request`.
// Note: field number 1 is not used by this message.
message UnaryOpRequest {
  // Which unary operation to perform; see UnaryOperation.
  UnaryOperation unop = 2;

  // Handle of the operand.
  ComputationDataHandle operand = 3;
}

// Selects the operation performed by a BinaryOpRequest (see
// BinaryOpRequest.binop).
//
// NOTE(review): values 11 and 13 are not assigned in this enum. They are
// presumably retired values; confirm the history before reusing them.
enum BinaryOperation {
  // Zero enum value, and therefore the proto3 default when the field is unset.
  BINOP_INVALID = 0;

  // Arithmetic operations.
  BINOP_ADD = 1;
  BINOP_DIV = 2;
  BINOP_MUL = 3;
  BINOP_SUB = 4;

  // Comparison operators.
  BINOP_EQ = 5;
  BINOP_GE = 6;
  BINOP_GT = 7;
  BINOP_LE = 8;
  BINOP_LT = 9;
  BINOP_NE = 10;

  // Dot product, matrix multiply.
  BINOP_DOT = 12;

  // Element-wise maximum.
  BINOP_MAX = 14;

  // Element-wise minimum.
  BINOP_MIN = 15;

  // Raises the left-hand-side to the right-hand-side power.
  BINOP_POW = 16;

  // Remainder operation.
  BINOP_REM = 17;

  // Element-wise, logical operators on booleans and bitwise operators on ints.
  BINOP_AND = 18;
  BINOP_OR = 19;

  // Shift operations, as named: left shift, arithmetic right shift and
  // logical right shift.
  BINOP_SHIFT_LEFT = 20;
  BINOP_SHIFT_RIGHT_ARITHMETIC = 21;
  BINOP_SHIFT_RIGHT_LOGICAL = 22;

  // Complex from real, imag.
  BINOP_COMPLEX = 23;

  // Computes the 4-quadrant arctangent of the y, x input arguments.
  BINOP_ATAN2 = 24;
}

// Request carried in OpRequest's `op` oneof as `binary_op_request`.
// Note: field number 1 is not used by this message.
message BinaryOpRequest {
  // Which binary operation to perform; see BinaryOperation.
  BinaryOperation binop = 2;

  // Handle of the left-hand-side operand.
  ComputationDataHandle lhs = 3;

  // Handle of the right-hand-side operand.
  ComputationDataHandle rhs = 4;

  // NOTE(review): presumably maps the dimensions of the lower-rank operand
  // onto the higher-rank operand when the ranks differ -- confirm.
  repeated int64 broadcast_dimensions = 5;
}

// Selects the distribution sampled by an RngRequest (see
// RngRequest.distribution). The "parameter[i]" references below refer to the
// entries of RngRequest.parameter.
enum RandomDistribution {
  // Zero enum value, and therefore the proto3 default when the field is unset.
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;

  // Creates a Bernoulli-distribution-generated random number with mean
  // parameter[0].
  RNG_BERNOULLI = 3;
}

// Request carried in OpRequest's `op` oneof as `rng_request`.
// Note: field number 1 is not used by this message.
message RngRequest {
  // Which distribution to sample from; see RandomDistribution.
  RandomDistribution distribution = 2;

  // Handles of the distribution parameters. Their meaning depends on
  // `distribution`; see the comments on RandomDistribution.
  repeated ComputationDataHandle parameter = 3;

  // NOTE(review): presumably the shape of the generated result -- confirm.
  Shape shape = 4;
}

// Selects the operation performed by a TernaryOpRequest (see
// TernaryOpRequest.triop).
//
// NOTE(review): value 2 is not assigned in this enum. It is presumably a
// retired value; confirm the history before reusing it.
enum TernaryOperation {
  // Zero enum value, and therefore the proto3 default when the field is unset.
  TRIOP_INVALID = 0;

  // Given a predicate and two operands, selects operand0 if the predicate is
  // true and operand1 if the predicate is false.
  TRIOP_SELECT = 1;

  // Given a min, max and an operand returns the operand if between min and max,
  // else returns min if operand is less than min or max if operand is greater
  // than max.
  TRIOP_CLAMP = 3;
}

// Request carried in OpRequest's `op` oneof as `ternary_op_request`.
// Note: field number 1 is not used by this message.
//
// NOTE(review): how lhs/rhs/ehs map onto the predicate/operands of
// TRIOP_SELECT or the min/operand/max of TRIOP_CLAMP is not stated in this
// file -- confirm against the implementation.
message TernaryOpRequest {
  // Which ternary operation to perform; see TernaryOperation.
  TernaryOperation triop = 2;

  // Handles of the three operands.
  ComputationDataHandle lhs = 3;
  ComputationDataHandle rhs = 4;
  ComputationDataHandle ehs = 5;
}

// Selects the operation performed by a VariadicOpRequest (see
// VariadicOpRequest.varop).
enum VariadicOperation {
  // Zero enum value, and therefore the proto3 default when the field is unset.
  VAROP_INVALID = 0;

  // Creates a tuple from its operands.
  VAROP_TUPLE = 1;
}

// Request carried in OpRequest's `op` oneof as `variadic_op_request`.
// Note: field number 1 is not used by this message.
message VariadicOpRequest {
  // Which variadic operation to perform; see VariadicOperation.
  VariadicOperation varop = 2;

  // Handles of the operands.
  repeated ComputationDataHandle operands = 3;
}

// Request carried in OpRequest's `op` oneof as `reduce_precision_request`.
//
// NOTE(review): presumably reduces the operand's floating-point values to the
// precision of a format with the given exponent and mantissa widths --
// confirm.
message ReducePrecisionRequest {
  // Handle of the operand.
  ComputationDataHandle operand = 1;

  // Number of exponent bits.
  int32 exponent_bits = 2;

  // Number of mantissa bits.
  int32 mantissa_bits = 3;
}

// Request carried in OpRequest's `op` oneof as `send_request`.
message SendRequest {
  // Handle of the operand.
  // NOTE(review): presumably the data to send over the channel -- confirm.
  ComputationDataHandle operand = 1;

  // The channel to use; see ChannelHandle.
  ChannelHandle channel_handle = 2;
}

// Request carried in OpRequest's `op` oneof as `recv_request`.
message RecvRequest {
  // NOTE(review): presumably the shape of the data to receive -- confirm.
  Shape shape = 1;

  // The channel to use; see ChannelHandle.
  ChannelHandle channel_handle = 2;
}

// Describes how an operation is sharded across devices. Attached to requests
// through OpRequest.sharding; nests recursively through tuple_shardings.
message OpSharding {
  // The kind of sharding; determines which of the other fields are used.
  //
  // Note: the zero value is REPLICATED, so an OpSharding whose `type` is unset
  // reads as REPLICATED.
  enum Type {
    // This sharding is replicated across all devices (implies maximal,
    // all other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
  }
  // The kind of sharding described by this message.
  Type type = 1;
  // The shape of the sharded tile.
  Shape tile_shape = 2;
  // The shape of the tile assignment tensor - this must be the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  repeated int64 tile_assignment_devices = 4;
  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
  // in pre-order. The tuple shape could be nested; here we store just a
  // flattened list of all leaves in the tuple shape. Note that the tuple shape
  // is not stored here; shardings do not store the shapes to which they are
  // applied, this is inferred from the instruction this sharding gets attached
  // to.
  repeated OpSharding tuple_shardings = 5;
}

// Tagged union of all operation requests: exactly one member of the `op`
// oneof can be set at a time, alongside the common fields `computation`,
// `metadata` and `sharding`.
//
// Field numbers are not in declaration order; the oneof members are listed
// roughly alphabetically, with later additions appended. Numbers 1-38, 40
// and 41 are in use.
// NOTE(review): number 39 is not assigned. It is presumably a retired
// field; confirm the history before reusing it.
message OpRequest {
  // Handle of the computation.
  // NOTE(review): presumably the computation this op is added to -- confirm.
  ComputationHandle computation = 1;

  // Debugging/profiling metadata for the op; see OpMetadata.
  OpMetadata metadata = 33;

  // Sharding for the op; see OpSharding.
  OpSharding sharding = 40;

  // The operation being requested.
  oneof op {
    BinaryOpRequest binary_op_request = 2;
    BroadcastRequest broadcast_request = 3;
    CallRequest call_request = 4;
    ConcatenateRequest concatenate_request = 5;
    ConstantRequest constant_request = 6;
    ConvertRequest convert_request = 7;
    ConvolveRequest convolve_request = 8;
    CrossReplicaSumRequest cross_replica_sum_request = 9;
    CustomCallRequest custom_call_request = 10;
    DynamicSliceRequest dynamic_slice_request = 11;
    DynamicUpdateSliceRequest dynamic_update_slice_request = 12;
    GetTupleElementRequest get_tuple_element_request = 13;
    InfeedRequest infeed_request = 14;
    MapRequest map_request = 15;
    PadRequest pad_request = 16;
    ParameterRequest parameter_request = 17;
    ReducePrecisionRequest reduce_precision_request = 36;
    ReduceRequest reduce_request = 18;
    ReduceWindowRequest reduce_window_request = 19;
    ReshapeRequest reshape_request = 20;
    ReverseRequest reverse_request = 21;
    RngRequest rng_request = 22;
    SelectAndScatterRequest select_and_scatter_request = 23;
    SliceRequest slice_request = 24;
    TernaryOpRequest ternary_op_request = 25;
    TraceRequest trace_request = 26;
    TransposeRequest transpose_request = 34;
    UnaryOpRequest unary_op_request = 27;
    VariadicOpRequest variadic_op_request = 28;
    WhileRequest while_request = 29;
    SendRequest send_request = 30;
    RecvRequest recv_request = 31;
    OutfeedRequest outfeed_request = 32;
    BatchNormTrainingRequest batch_norm_training_request = 35;
    BatchNormGradRequest batch_norm_grad_request = 37;
    BatchNormInferenceRequest batch_norm_inference_request = 38;
    FftRequest fft_request = 41;
    // Next: 42
  }
}

// NOTE(review): presumably the response paired with OpRequest -- confirm.
message OpResponse {
  // Handle of the output data.
  // NOTE(review): presumably the result of the requested op, usable as an
  // operand of later requests -- confirm.
  ComputationDataHandle output = 1;
}