aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/xla_builder.h
blob: 9ceede7a795168716120e74f18a8053390e92801 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_

#include <map>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class XlaBuilder;

// This represents an instruction that has been enqueued using the XlaBuilder.
// This is used to pass to subsequent computations that depend upon the
// instruction as an operand.
class XlaOp {
 public:
  XlaOp() : handle_(-1), builder_(nullptr) {
    static_assert(std::is_trivially_destructible<XlaOp>::value,
                  "XlaOp should be trivially destructible");
  }
  ~XlaOp() = default;

  // Precondition: !IsUninitialized().
  //
  // It's very common to do foo.builder()->bar().  Without this precondition, if
  // foo.builder() is null, the call to bar will segfault at some point possibly
  // deep in the callstack when we finally dereference `this`.  The precondition
  // lets us avoid this tricky-to-debug problem.
  XlaBuilder* builder() const {
    CHECK(builder_ != nullptr);
    return builder_;
  }

  // Returns true if the XlaOp represents valid, non-erroneous value.
  bool valid() const { return handle_ >= 0; }

  // Returns true if the XlaOp was created by the XlaOp() constructor and
  // not returned by a builder.
  bool IsUninitialized() const { return builder_ == nullptr; }

  bool IsIdenticalTo(const XlaOp& rhs) const {
    return handle_ == rhs.handle_ && builder_ == rhs.builder_;
  }

  friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) {
    out << op.handle();
    return out;
  }

 private:
  explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
  XlaOp(int64 handle, XlaBuilder* builder)
      : handle_(handle), builder_(builder) {}

  int64 handle() const { return handle_; }

  friend class XlaBuilder;

  // < 0 means "invalid handle".
  int64 handle_;

  // Not owned. Non-null for any handle returned by XlaBuilder, even if the
  // handle is invalid.
  XlaBuilder* builder_;
};

// Arithmetic operator overloads for the XlaOp type.
//
// These are declarations only; their definitions are not in this header.
// NOTE(review): presumably each one enqueues the corresponding elementwise
// operation on the operands' XlaBuilder -- confirm against the definitions.
// Unary minus (one operand).
XlaOp operator-(const XlaOp& x);
// Binary arithmetic on two operands.
XlaOp operator+(const XlaOp& x, const XlaOp& y);
XlaOp operator-(const XlaOp& x, const XlaOp& y);
XlaOp operator*(const XlaOp& x, const XlaOp& y);
XlaOp operator/(const XlaOp& x, const XlaOp& y);
XlaOp operator%(const XlaOp& x, const XlaOp& y);

// Bitwise operator overloads for the XlaOp type.
// Unary bitwise complement (one operand).
XlaOp operator~(const XlaOp& x);
// Binary bitwise and shift operators on two operands.
XlaOp operator&(const XlaOp& x, const XlaOp& y);
XlaOp operator|(const XlaOp& x, const XlaOp& y);
XlaOp operator^(const XlaOp& x, const XlaOp& y);
XlaOp operator<<(const XlaOp& x, const XlaOp& y);
// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs
// a right logical shift.
XlaOp operator>>(const XlaOp& x, const XlaOp& y);

// We don't overload the relational operators (==, !=, <, <=, >, >=) because the
// semantics might be surprising since their result types are usually 'bool'.
// Further programmers may expect == to be a structural equality.
// We also choose not to overload any of the mutating operators (e.g., +=, -=)
// because the semantics might be misleading — XLA computations are immutable.

// A convenient interface for building up computations.
//
// Thread-compatible.
class XlaBuilder {
 public:
  // computation_name: name to use for the built computation.
  XlaBuilder(const string& computation_name);

  XlaBuilder(const XlaBuilder&) = delete;
  XlaBuilder& operator=(const XlaBuilder&) = delete;

  ~XlaBuilder();

  // Returns the computation name.
  const string& name() const { return name_; }

  // Sets OpMetadata that will be added to all instructions until cleared.
  //
  // OpMetadata is often applied to a series of XLA HLO instructions. As a
  // result, OpMetadata is set on the Computation Builder. All subsequent
  // instructions generated via this Computation Builder will have the same
  // OpMetadata attached until a call to ClearOpMetadata.
  void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }

  // Clears the OpMetadata state.
  void ClearOpMetadata() { metadata_.Clear(); }

  // Sets an OpSharding that will be attached to all instructions until cleared.
  void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }

  // Clears the sharding. Ops will be sharded according to the default placement
  // policy.
  void ClearSharding() { sharding_ = absl::nullopt; }

  // Returns the OpSharding that will be attached to all instructions.
  const absl::optional<OpSharding>& sharding() const { return sharding_; }

  // Sets the builder to a mode where it will die immediately when an error is
  // encountered, rather than producing it in a deferred fashion when Build() is
  // called (which is the default).
  void set_die_immediately_on_error(bool enabled) {
    die_immediately_on_error_ = enabled;
  }

  // Default dimension numbers used for a 2D convolution.
  static constexpr int64 kConvBatchDimension = 0;
  static constexpr int64 kConvFeatureDimension = 1;
  static constexpr int64 kConvFirstSpatialDimension = 2;
  static constexpr int64 kConvSecondSpatialDimension = 3;
  static constexpr int64 kConvKernelOutputDimension = 0;
  static constexpr int64 kConvKernelInputDimension = 1;
  static constexpr int64 kConvKernelFirstSpatialDimension = 2;
  static constexpr int64 kConvKernelSecondSpatialDimension = 3;

  // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
  // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
  // the kernel operand
  // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
  static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
      int num_spatial_dims = 2);

  // Returns an error if the convolution dimension numbers have conflicts.
  static Status Validate(const ConvolutionDimensionNumbers& dnum);

  // Returns a new XlaBuilder whose resultant Computation is used only by this
  // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
  // behavior as the parent.
  std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);

  // Builds the computation with the requested operations, or returns a non-ok
  // status. Note that all ops that have been enqueued will be moved to the
  // computation being returned. The root of the computation will be the last
  // added operation.
  StatusOr<XlaComputation> Build();

  // Overload of Build which specifies a particular root instruction for the
  // computation.
  StatusOr<XlaComputation> Build(XlaOp root);

  // Builds the computation with the requested operations, or notes an error in
  // the parent XlaBuilder and returns an empty computation if building failed.
  // This function is intended to be used where the returned XlaComputation is
  // only used by the parent XlaBuilder and hence further operation on the
  // returned XlaComputation will simply be error'ed out if an error occurred
  // while building this computation. If the built computation is to be used by
  // an XlaBuilder other than the parent XlaBuilder, then Build() should be used
  // instead.
  XlaComputation BuildAndNoteError();

  // Returns a subgraph that roots on the given root. If the root is not a
  // compile-time constant (see `IsConstant`), returns an error.
  //
  // This will copy the needed ops/computations to the subgraph.
  StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const;

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous
  // XlaOp and inform the user of the error that occurred while
  // building the computation when they make a final call to Build().
  //
  // See also set_die_immediately_on_error().
  Status first_error() const { return first_error_; }

  // Returns the shape of the given op.
  StatusOr<Shape> GetShape(const XlaOp& op) const;

  // Returns the (inferred) result for the current computation's shape. This
  // assumes the root instruction is the last added instruction.
  StatusOr<ProgramShape> GetProgramShape() const;

  // Returns the (inferred) result for the current computation's shape using the
  // given operation as the root.
  StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;

  // Reports an error to the builder, by
  // * storing it internally and capturing a backtrace if it's the first error
  //   (this deferred value will be produced on the call to
  //    Build()/GetShape()/...)
  // * dying if die_immediately_on_error_ is true.
  // Returns an XlaOp with an invalid handle but a valid builder. This value can
  // be returned in place of a value in APIs that return an XlaOp.
  XlaOp ReportError(const Status& error);

  // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
  // If the Status was an error, reports the error to builder and returns an
  // invalid XlaOp handle.
  XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);

  // A helper function that runs a function that returns a StatusOr<XlaOp> and
  // returns an XlaOp.
  XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);

  // Returns true if 'operand' is a compile-time constant. A compile-time
  // constant does not depend on any parameters, or on stateful operators such
  // as `RngNormal` or `Infeed`.
  //
  // This tests whether a computation is a compile-time constant without
  // evaluating the computation.
  StatusOr<bool> IsConstant(const XlaOp& operand) const;

 private:
  // Build helper which takes the id of the root operation.
  StatusOr<XlaComputation> Build(int64 root_id);

  // Enqueues a "retrieve parameter value" instruction for a parameter that was
  // passed to the computation.
  XlaOp Parameter(int64 parameter_number, const Shape& shape,
                  const string& name);

  // Enqueues a constant with the value of the given literal onto the
  // computation.
  XlaOp ConstantLiteral(const LiteralSlice& literal);

  // Enqueues a constant onto the computation. Methods are templated on the
  // native host type (NativeT) which corresponds to a specific XLA
  // PrimitiveType as given in the following table:
  //
  //  Native Type   PrimitiveType
  // -----------------------------
  //   bool           PRED
  //   int32          S32
  //   int64          S64
  //   uint32         U32
  //   uint64         U64
  //   float          F32
  //   double         F64
  //
  // Note: not all primitive types defined in xla_data.proto have a
  // corresponding native type yet.
  template <typename NativeT>
  XlaOp ConstantR0(NativeT value);
  template <typename NativeT>
  XlaOp ConstantR1(absl::Span<const NativeT> values);
  XlaOp ConstantR1(const tensorflow::core::Bitmap& values);
  template <typename NativeT>
  XlaOp ConstantR2(
      std::initializer_list<std::initializer_list<NativeT>> values);
  template <typename NativeT>
  XlaOp ConstantFromArrayWithLayout(const Array<NativeT>& values,
                                    const Layout& layout);
  template <typename NativeT>
  XlaOp ConstantFromArray(const Array<NativeT>& values);
  template <typename NativeT>
  XlaOp ConstantR2FromArray2DWithLayout(const Array2D<NativeT>& values,
                                        const Layout& layout);
  template <typename NativeT>
  XlaOp ConstantR2FromArray2D(const Array2D<NativeT>& values);
  template <typename NativeT>
  XlaOp ConstantR3FromArray3DWithLayout(const Array3D<NativeT>& values,
                                        const Layout& layout);
  template <typename NativeT>
  XlaOp ConstantR3FromArray3D(const Array3D<NativeT>& values);
  template <typename NativeT>
  XlaOp ConstantR4FromArray4DWithLayout(const Array4D<NativeT>& values,
                                        const Layout& layout);
  template <typename NativeT>
  XlaOp ConstantR4FromArray4D(const Array4D<NativeT>& values);

  // Enqueues a rank one constant (vector) onto the computation. The vector has
  // size 'length' and every element has the value 'value'.
  template <typename NativeT>
  XlaOp ConstantR1(int64 length, NativeT value);

  // Adds dimensions to an array by duplicating the data in the array.
  //
  // The new dimensions are inserted on the left, i.e. if
  // broadcast_sizes has values {a0, ..., aN} and the operand shape
  // has dimensions {b0, ..., bM} then the shape of the output has
  // dimensions {a0, ..., aN, b0, ..., bM}.
  //
  // The new dimensions index into copies of the operand, i.e.
  //
  //   output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
  XlaOp Broadcast(const XlaOp& operand,
                  absl::Span<const int64> broadcast_sizes);

  // Performs in-dimension-style broadcast.
  //
  // Operand specifies the input to be broadcast. "shape" is the expected output
  // shape. "broadcast_dimensions" are the output dimensions being broadcast
  // into: the i-th entry of broadcast_dimensions gives the dimension of the
  // output shape that dimension i of the operand maps to.
  // e.g.
  // Say operand = [1, 2], i.e., a 1D tensor with 2 elements,
  // and the dimensions of shape are [2, 2].
  // Specifying {1} as broadcast_dimensions will generate output
  // [1 , 2]
  // [1 , 2]
  // On the other hand, specifying {0} as broadcast_dimensions
  // will generate output
  // [1 , 1]
  // [2 , 2]
  XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape,
                       const absl::Span<const int64> broadcast_dimensions);

  // Enqueues a pad operation onto the computation that pads the given value on
  // the edges as well as between the elements of the input. padding_config
  // specifies the padding amount for each dimension.
  XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
            const PaddingConfig& padding_config);

  // Enqueues an operation onto the computation that flattens the operand based
  // on the dimension order (major/slowest-varying to minor/fastest-varying)
  // given, followed by reshaping it into the shape with the given dimension
  // sizes (also major to minor). Conceptually, this is a limited form of
  // "shape casting".
  XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
                absl::Span<const int64> new_sizes);

  // Enqueues an operation onto the computation that collapses the operand, from
  // first to last dimension (C order), then reshapes it to the given dimension
  // sizes. Conceptually, this is a limited form of "shape casting".
  XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);

  // Wrapper for Reshape.
  // Enqueues an operation to collapse the provided dimensions; e.g. an
  // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
  // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
  // be a consecutive, in-order subsequence of the operand dimensions.
  //
  // Note that collapsing a single dimension does nothing:
  //
  //    {256} collapsing {0} => {256}
  //    {1} collapsing {0} => {1}
  //
  // Collapsing multiple dimensions produces a single result dimension:
  //
  //    {256, 2} collapsing {0,1} => {512}
  //    {256, 2, 3} collapsing {0,1} => {512, 3}
  //
  // This could potentially cause data to be moved -- it provides a more
  // structured form of reshaping than an arbitrary Reshape operation.
  XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);

  // Enqueues a slice operation onto the computation that slices the operand
  // from the start indices to the limit indices; e.g.
  //
  //        x
  //   [ 0 1 2 3 ]
  // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
  //   [ 8 9 a b ]
  //
  // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
  // range notation.
  // The strides parameter determines the stride over the slice.
  XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
              absl::Span<const int64> limit_indices,
              absl::Span<const int64> strides);

  // Enqueues a slice operation in a given dimension, taking all other
  // dimensions as they are; e.g. if dimno is 1 from start_index 2 to
  // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
  // for:
  //
  //  array[:, 2:4:1, :]
  XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
                   int64 stride, int64 dimno);

  // Enqueues a slice operation onto the computation that slices the 'operand'
  // from dynamic start indices which are passed in 'start_indices'.
  // The size of the slice in each dimension is passed in 'slice_sizes',
  // which specify the end point of exclusive slice intervals in each
  // dimension [start, start + size).
  // The shape of 'start_indices' must be rank == 1, with dimension size
  // equal to the rank of the 'operand'.
  // Slice index calculations are computed modulo input dimension sizes to
  // prevent dynamic start indices from generating out-of-bound array accesses.
  XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
                     absl::Span<const int64> slice_sizes);

  // Enqueues a dynamic update slice operation onto the computation, which
  // updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
  // The shape of 'update' determines the shape of the slice of 'operand'
  // which is updated.
  // The indices specified in 'start_indices' specify the offset of the slice
  // of 'operand' which is updated.
  //
  //               update = {10, 11} // calculated at runtime.
  //   [1 2 3]     start  = {1, 1}   // calculated at runtime.  [1 2  3 ]
  //   [4 5 6]  => DynamicUpdateSlice(data, update, start)   => [4 10 11]
  //   [7 8 9]                                                  [7 8  9 ]
  //
  // The shape of 'start_indices' must be rank == 1, with dimension size
  // equal to the rank of the 'operand'.
  // Slice index calculations are computed modulo update dimension sizes to
  // prevent dynamic start indices from generating out-of-bound array accesses.
  XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                           const XlaOp& start_indices);

  // Enqueues a concatenate instruction onto the computation, joining
  // 'operands' along 'dimension'. 'operands' must have >= 1 entry.
  XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);

  // Enqueues a tracing operation onto the computation; the computation will
  // emit a logging message with the operand.
  void Trace(const string& tag, const XlaOp& operand);

  // Enqueues a conditional-move-like select operation onto the computation;
  // predicated on pred, selects between on_true and on_false.
  XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);

  // Enqueues a tuple-creation instruction onto the computation.
  XlaOp Tuple(absl::Span<const XlaOp> elements);

  // Enqueues a tuple-element-get instruction onto the computation, extracting
  // element 'index' of 'tuple_data'.
  XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);

  // Enqueues an equal-to comparison instruction onto the computation.
  XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
           absl::Span<const int64> broadcast_dimensions = {});

  // Enqueues a not-equal comparison instruction onto the computation.
  XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
           absl::Span<const int64> broadcast_dimensions = {});

  // Enqueues a greater-or-equal comparison instruction onto the computation.
  XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
           absl::Span<const int64> broadcast_dimensions = {});

  // Enqueues a greater-than comparison instruction onto the computation.
  XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
           absl::Span<const int64> broadcast_dimensions = {});

  // Enqueues a less-than comparison instruction onto the computation.
  XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
           absl::Span<const int64> broadcast_dimensions = {});

  // Enqueues a less-or-equal comparison instruction onto the computation.
  XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
           absl::Span<const int64> broadcast_dimensions = {});

  // Enqueues a dot instruction onto the computation. 'precision_config' is
  // optional and may be nullptr.
  XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
            const PrecisionConfig* precision_config = nullptr);

  // Enqueues a general dot instruction onto the computation, configured by
  // 'dimension_numbers'. 'precision_config' is optional and may be nullptr.
  XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                   const DotDimensionNumbers& dimension_numbers,
                   const PrecisionConfig* precision_config = nullptr);

  // Enqueues a convolution instruction onto the computation, which uses the
  // default convolution dimension numbers.
  XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
             absl::Span<const int64> window_strides, Padding padding,
             int64 feature_group_count = 1,
             const PrecisionConfig* precision_config = nullptr);

  // Enqueues a convolution instruction onto the computation, with the caller
  // provided padding configuration in the format returned by MakePadding().
  XlaOp ConvWithGeneralPadding(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64 feature_group_count = 1,
      const PrecisionConfig* precision_config = nullptr);

  // Enqueues a convolution instruction onto the computation, with the caller
  // provided dimension numbers configuration.
  XlaOp ConvWithGeneralDimensions(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides, Padding padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count = 1,
      const PrecisionConfig* precision_config = nullptr);

  // Enqueues a convolution instruction onto the computation, with the caller
  // provided padding configuration as well as the dimension numbers.
  XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
                    absl::Span<const int64> window_strides,
                    absl::Span<const std::pair<int64, int64>> padding,
                    const ConvolutionDimensionNumbers& dimension_numbers,
                    int64 feature_group_count = 1,
                    const PrecisionConfig* precision_config = nullptr);

  // Enqueues a convolution instruction onto the computation, with the caller
  // provided padding configuration, dilation factors and dimension numbers.
  XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
                           absl::Span<const int64> window_strides,
                           absl::Span<const std::pair<int64, int64>> padding,
                           absl::Span<const int64> lhs_dilation,
                           absl::Span<const int64> rhs_dilation,
                           const ConvolutionDimensionNumbers& dimension_numbers,
                           int64 feature_group_count = 1,
                           const PrecisionConfig* precision_config = nullptr);

  // Enqueues an FFT instruction onto the computation, of the given type and
  // with the given FFT length.
  XlaOp Fft(const XlaOp& operand, FftType fft_type,
            absl::Span<const int64> fft_length);

  // Enqueues an infeed instruction onto the computation, which reads data of
  // the given shape from the infeed buffer of the device.
  XlaOp Infeed(const Shape& shape, const string& config = "");
  // Variant of Infeed() that additionally takes a 'token' operand.
  XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
                        const string& config = "");

  // Enqueues an outfeed instruction onto the computation. This instruction
  // generates outgoing data transfers for the given data.
  //
  // shape_with_layout communicates the laid out shape that we want to outfeed
  // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
  // will occur.
  void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
               const string& outfeed_config);
  // Variant of Outfeed() that additionally takes a 'token' operand and returns
  // the resulting op.
  XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
                         const Shape& shape_with_layout,
                         const string& outfeed_config);

  // Enqueues a call instruction onto the computation.
  XlaOp Call(const XlaComputation& computation,
             absl::Span<const XlaOp> operands);

  // Enqueues a custom call instruction onto the computation, targeting
  // 'call_target_name'. 'operand_shapes_with_layout' is optional.
  XlaOp CustomCall(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const Shape& shape_with_layout, const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);

  // The following methods enqueue element-wise binary arithmetic operations
  // onto the computation. The shapes of the operands have to match unless one
  // of the operands is a scalar, or an explicit broadcast dimension is given
  // (see g3doc for more details).

  // Enqueues a complex compose instruction onto the computation.
  XlaOp Complex(const XlaOp& real, const XlaOp& imag,
                absl::Span<const int64> broadcast_dimensions = {});

  // Enqueues a complex conjugate instruction onto the computation.
  XlaOp Conj(const XlaOp& operand);

  // Enqueues an add instruction onto the computation.
  XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
            absl::Span<const int64> broadcast_dimensions = {});

  // Enqueues a subtract instruction onto the computation.
  XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
            absl::Span<const int64> broadcast_dimensions = {});

  // Enqueues a multiply instruction onto the computation.
  XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
            absl::Span<const int64> broadcast_dimensions = {});

  // Enqueues a divide instruction onto the computation.
  XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
            absl::Span<const int64> broadcast_dimensions = {});

  // Enqueues a remainder instruction onto the computation.
  XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
            absl::Span<const int64> broadcast_dimensions = {});

  // Enqueues a max instruction onto the computation.
  XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
            absl::Span<const int64> broadcast_dimensions = {});

  // Enqueues a min instruction onto the computation.
  XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
            absl::Span<const int64> broadcast_dimensions = {});

  // Element-wise logical operators
  XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
            absl::Span<const int64> broadcast_dimensions = {});

  XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
           absl::Span<const int64> broadcast_dimensions = {});

  XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
            absl::Span<const int64> broadcast_dimensions = {});

  // Element-wise logical negation; unlike the operators above it is unary.
  XlaOp Not(const XlaOp& operand);

  // Element-wise shift operators: shift 'lhs' by 'rhs' (left, arithmetic
  // right and logical right, respectively).
  XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions = {});
  XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
                             absl::Span<const int64> broadcast_dimensions = {});
  XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
                          absl::Span<const int64> broadcast_dimensions = {});

  // Reduces an array among the provided dimensions, given "computation" as a
  // reduction operator.
  XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  // Reduces several arrays simultaneously among the provided dimensions, given
  // "computation" as a reduction operator.
  XlaOp Reduce(absl::Span<const XlaOp> operands,
               absl::Span<const XlaOp> init_values,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  // Convenience wrapper around the above that reduces all the dimensions in the
  // operand shape.
  XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
                  const XlaComputation& computation);

  // Enqueues a windowed reduce instruction onto the computation.
  XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
                     const XlaComputation& computation,
                     absl::Span<const int64> window_dimensions,
                     absl::Span<const int64> window_strides, Padding padding);

  // As ReduceWindow(), but the padding is given in the format
  // returned by MakePadding(), and base and window dilation factors are
  // supplied explicitly.
  XlaOp ReduceWindowWithGeneralPadding(
      const XlaOp& operand, const XlaOp& init_value,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);

  // Returns the sum of the operand value within each subgroup of replicas. All
  // replicas supply one input to the sum and all replicas receive the resulting
  // sum for each subgroup.
  XlaOp CrossReplicaSum(const XlaOp& operand,
                        absl::Span<const ReplicaGroup> replica_groups = {});

  // Enqueues an operation that does an AllReduce of the operand across cores.
  // Here AllReduce means doing a reduction on the input operand across cores
  // and then broadcasting the reduction result to those cores. The reduction
  // function is defined by `computation`, which should be a commutative
  // computation on scalars, e.g., add, min, or max. The way that AllReduce is
  // applied is configured by:
  //
  // - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
  // empty, all replicas belong to one group. Allreduce will be applied within
  // subgroups. For example, with 4 replicas, replica_groups={{0,2},{1,3}}
  // means replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in
  // subgroup 1.
  //
  // - `channel_id`: for Allreduce nodes from different modules, if they have
  // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will
  // not be applied across modules.
  //
  // TODO(b/79737069): Rename this to AllReduce when it's ready to use.
  XlaOp CrossReplicaSum(
      const XlaOp& operand, const XlaComputation& computation,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const absl::optional<ChannelHandle>& channel_id = absl::nullopt);

  // Enqueues an operation that does an AllToAll of the operand across cores.
  XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
                 int64 concat_dimension, int64 split_count,
                 const std::vector<ReplicaGroup>& replica_groups);

  // Enqueues an operation that does a CollectivePermute of the operand across
  // cores, as described by 'source_target_pairs'.
  XlaOp CollectivePermute(
      const XlaOp& operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);

  // Enqueues an operation that scatters the `source` array to the selected
  // indices of each window.
  XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
                         absl::Span<const int64> window_dimensions,
                         absl::Span<const int64> window_strides,
                         Padding padding, const XlaOp& source,
                         const XlaOp& init_value,
                         const XlaComputation& scatter);

  // As SelectAndScatter(), but the padding is given in the format
  // returned by MakePadding().
  XlaOp SelectAndScatterWithGeneralPadding(
      const XlaOp& operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
      const XlaOp& init_value, const XlaComputation& scatter);

  // Enqueues an abs instruction onto the computation.
  XlaOp Abs(const XlaOp& operand);

  // Enqueues an atan2 instruction onto the computation.
  XlaOp Atan2(const XlaOp& y, const XlaOp& x,
              absl::Span<const int64> broadcast_dimensions = {});

  // Enqueues an exp instruction onto the computation.
  XlaOp Exp(const XlaOp& operand);

  // Enqueues an expm1 instruction onto the computation.
  XlaOp Expm1(const XlaOp& operand);

  // Enqueues a floor instruction onto the computation.
  XlaOp Floor(const XlaOp& operand);

  // Enqueues a ceil instruction onto the computation.
  XlaOp Ceil(const XlaOp& operand);

  // Enqueues a round instruction onto the computation, rounding to the nearest
  // integer with half-way cases rounding away from zero.
  XlaOp Round(const XlaOp& operand);

  // Enqueues a log instruction (natural logarithm) onto the computation.
  XlaOp Log(const XlaOp& operand);

  // Enqueues a log1p instruction (log(x+1)) onto the computation.
  XlaOp Log1p(const XlaOp& operand);

  // Enqueues a sign instruction onto the computation.
  XlaOp Sign(const XlaOp& operand);

  // Enqueues a count leading zeros instruction onto the computation.
  XlaOp Clz(const XlaOp& operand);

  // Enqueues a cosine instruction onto the computation.
  XlaOp Cos(const XlaOp& operand);

  // Enqueues a sine instruction onto the computation.
  XlaOp Sin(const XlaOp& operand);

  // Enqueues a tanh instruction onto the computation.
  XlaOp Tanh(const XlaOp& operand);

  // Enqueues a real-part instruction onto the computation.
  XlaOp Real(const XlaOp& operand);

  // Enqueues an imaginary-part instruction onto the computation.
  XlaOp Imag(const XlaOp& operand);

  // Enqueues a lhs^rhs computation onto the computation.
  XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
            absl::Span<const int64> broadcast_dimensions = {});

  // Enqueues an operator that tests if the operand's values are finite, i.e.,
  // not Inf or NaN. Defined only for floating-point types. Returns an array of
  // booleans with the same shape where entries are true iff the corresponding
  // entry was finite (neither Inf nor NaN).
  XlaOp IsFinite(const XlaOp& operand);

  // Enqueues an iota operation onto the computation.
  XlaOp Iota(const Shape& shape, int64 iota_dimension);

  // Enqueues a rank-1 iota operation onto the computation.
  XlaOp Iota(PrimitiveType type, int64 size);

  // Enqueues a convert instruction onto the computation that changes the
  // element type of the operand array to new_element_type.
  XlaOp ConvertElementType(const XlaOp& operand,
                           PrimitiveType new_element_type);

  // Enqueues a no-op instruction onto the computation that changes
  // the element type of the operand array to new_element_type. The
  // bit-widths of the source and destination element types must be
  // identical.
  XlaOp BitcastConvertType(const XlaOp& operand,
                           PrimitiveType new_element_type);

  // Enqueues a negate instruction onto the computation.
  XlaOp Neg(const XlaOp& operand);

  // Enqueues a transpose instruction onto the computation.
  XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);

  // Enqueues a reverse instruction onto the computation. The order of the
  // elements in the given dimensions is reversed (i.e., the element at index i
  // is moved to index dimension_size - 1 - i).
  XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);

  // Enqueues a sort (in increasing order) instruction onto the computation.
  // If only keys are provided:
  // * If the keys are a rank-1 tensor (an array), the result is a sorted array
  // of keys, in ascending order.
  // * If the keys have higher rank, the keys are sorted along the provided
  // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
  // value of 0 will independently sort every column, and a dimension value of 1
  // will independently sort each row. If no dimension number is provided, then
  // the last dimension is chosen by default.
  //
  // If both keys and values are provided:
  // * The keys and the values must be tensors with the same dimensions. The
  // element types of the tensors may be different.
  // * The result is a tuple that consists of a sorted tensor of keys (along the
  // provided dimension, as above) as the first element, and a tensor with their
  // corresponding values as the second element.
  XlaOp Sort(XlaOp keys, absl::optional<XlaOp> values = absl::nullopt,
             int64 dimension = -1);

  // Enqueues a clamp instruction onto the computation.
  XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);

  // Enqueues a map instruction onto the computation.
  XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
            absl::Span<const int64> dimensions,
            absl::Span<const XlaOp> static_operands = {});

  // Enqueues a N(mu, sigma) random number generation instruction onto the
  // computation.
  XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);

  // Enqueues a U(a, b) random number generation instruction onto the
  // computation. Returns values in the semi-open interval [a, b).
  XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);

  // Enqueues a while node onto the computation.
  XlaOp While(const XlaComputation& condition, const XlaComputation& body,
              const XlaOp& init);

  // Enqueues a conditional node onto the computation.
  XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
                    const XlaComputation& true_computation,
                    const XlaOp& false_operand,
                    const XlaComputation& false_computation);

  // Enqueues a ReducePrecision node onto the computation.
  XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
                        const int mantissa_bits);

  // Enqueues a Gather node onto the computation.
  XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
               const GatherDimensionNumbers& dimension_numbers,
               absl::Span<const int64> slice_sizes);

  // Enqueues a Scatter node onto the computation.
  XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
                const XlaOp& updates, const XlaComputation& update_computation,
                const ScatterDimensionNumbers& dimension_numbers);

  // Enqueues a Send node onto the computation for device-to-device
  // communication, to send the given operand to a Recv instruction that shares
  // the same channel handle.
  void Send(const XlaOp& operand, const ChannelHandle& handle);
  XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
                      const ChannelHandle& handle);

  // Enqueues a Send node which sends data to the host.
  XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
                   const Shape& shape_with_layout, const ChannelHandle& handle);

  // Enqueues a Recv node which receives data from the host.
  XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
                     const ChannelHandle& handle);

  // Enqueues an AfterAll operation with no operands producing a token-shaped
  // value.
  XlaOp CreateToken();

  // Enqueues an AfterAll operation that takes the given `tokens` as operands
  // and produces a single token-shaped value.
  XlaOp AfterAll(absl::Span<const XlaOp> tokens);

  // Enqueues a Recv node onto the computation. The data comes from a Send
  // instruction that shares the same channel handle and its shape must
  // be the same as the given shape.
  XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
  XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
                      const ChannelHandle& handle);

  // Normalizes operand across spatial and batch dimensions for each feature.
  //
  // Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
  // is the normalized result and batch_mean and batch_var are the mean and
  // variance, respectively, across batch for the operand.
  XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
                          const XlaOp& offset, float epsilon,
                          int64 feature_index);

  // Normalizes operand across spatial and batch dimensions for each feature.
  //
  // `BatchNormInference` is equivalent to calling `BatchNormTraining` without
  // computing `mean` and `variance` for each batch inside the operation. It
  // uses the input `mean` and `variance` instead as estimated values. The
  // purpose of this op is to reduce latency in inference, hence the name
  // `BatchNormInference`.
  //
  // The output has the same shape as `operand`, and contains the normalized
  // values for each batch.
  XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
                           const XlaOp& offset, const XlaOp& mean,
                           const XlaOp& variance, float epsilon,
                           int64 feature_index);

  // Calculates the gradients of a batch norm op.
  //
  // The inputs `batch_mean` and `batch_var` represent the mean and variance
  // across the batch.
  //
  // Returns a tuple of three elements, in this order:
  //   - grad_operand: Gradient with respect to input `operand`
  //   - grad_scale: Gradient with respect to input `scale`
  //   - grad_offset: Gradient with respect to input `offset`
  XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
                      const XlaOp& batch_mean, const XlaOp& batch_var,
                      const XlaOp& grad_output, float epsilon,
                      int64 feature_index);

  // Adds `instr` to the computation being built as an instruction with the
  // given `opcode` and `operands`, returning an XlaOp that refers to it or an
  // error status. `instr` is taken by rvalue reference, so callers are
  // expected to std::move it in.
  StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
                                 absl::Span<const XlaOp> operands = {});

  // Records `computation` as a computation called by `instr`.
  // NOTE(review): presumably also stores it in embedded_ -- confirm against
  // the implementation.
  void AddCalledComputation(const XlaComputation& computation,
                            HloInstructionProto* instr);

  // Returns the instruction proto that `op` (or the raw `handle`) refers to,
  // or an error status if the lookup fails.
  StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
  StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
      int64 handle) const;

  // Internal helper method that does the building for an arbitrary unary op.
  XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);

  // Internal helper method that does the building for an arbitrary binary op.
  // broadcast_dimensions specifies which dimensions to use for broadcasting
  // when the operation is between tensors of different ranks.
  XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
                 absl::Span<const int64> broadcast_dimensions);

  // Internal helper method that does the building for an arbitrary ternary op.
  XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
                  const XlaOp& ehs);

  // Internal helper method that builds a random-number-generation op with the
  // given `distribution`, distribution `parameters` and result `shape`.
  // NOTE(review): presumably shared by RngNormal() and RngUniform() -- confirm.
  XlaOp RngOp(RandomDistribution distribution,
              absl::Span<const XlaOp> parameters, const Shape& shape);

  // Internal helper method that broadcasts `operand` into `shape`, with
  // `broadcast_dimensions` naming the output dimensions that the operand's
  // dimensions map to. NOTE(review): semantics inferred from the signature --
  // confirm against the implementation.
  StatusOr<XlaOp> InDimBroadcast(const Shape& shape, const XlaOp& operand,
                                 absl::Span<const int64> broadcast_dimensions);

  // Internal helper method that creates a sequence of instructions that
  // performs an explicit broadcast of the operand to the target shape.
  StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
                                       const XlaOp& operand);

  // Internal helper method for creating a Reshape op with the already inferred
  // shape.
  StatusOr<XlaOp> Reshape(const Shape& shape, const XlaOp& operand);

  // Returns the (inferred) result for the program shape using the given root.
  StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;

  // Returns shapes for the operands.
  StatusOr<std::vector<Shape>> GetOperandShapes(
      absl::Span<const XlaOp> operands) const;

  // A visitor which checks whether an operation is a compile-time constant,
  // meaning that it doesn't depend on any parameters, or on any stateful
  // operation such as `RngNormal` or `Infeed`. The visitor walks the
  // computation starting at a given operation and sets is_constant to false iff
  // a parameter or stateful operation is encountered.
  void IsConstantVisitor(const int64 op_handle, std::set<int64>* visited,
                         bool* is_constant) const;

  // Checks bounds for convolution parameters.
  Status VerifyConvolution(
      const Shape& lhs_shape, const Shape& rhs_shape,
      const ConvolutionDimensionNumbers& dimension_numbers) const;

  // Helper function for creating a Window proto from user-supplied data.
  // Returns error if the user-supplied data was invalid.
  StatusOr<Window> MakeWindow(absl::Span<const int64> window_dimensions,
                              absl::Span<const int64> window_strides,
                              absl::Span<const std::pair<int64, int64>> padding,
                              absl::Span<const int64> lhs_dilation,
                              absl::Span<const int64> rhs_dilation) const;

  string name_;  // Name to use for the built computation.

  // The first error encountered while building the computation.
  // It holds an OK status until the first error is encountered.
  Status first_error_;

  // The saved stack trace from the point at which the first error occurred.
  tensorflow::SavedStackTrace first_error_backtrace_;

  // The instructions of this computation.
  std::vector<HloInstructionProto> instructions_;

  // A map from XlaOp::Handle to the index in the instructions_ vector where the
  // instruction is held.
  absl::flat_hash_map<int64, int64> handle_to_index_;

  // The embedded computations used by this computation. Each computation was
  // the entry computation of some XlaComputation, the key is the unique id of
  // that XlaComputation.
  std::map<int64, HloComputationProto> embedded_;

  // The unique parameter numbers.
  absl::flat_hash_set<int64> parameter_numbers_;

  // The metadata to attach to each op. This is structured as a "modal"-like
  // operation, in order to simplify client code (and not sprinkle this metadata
  // throughout the TensorFlow op kernel implementations).
  OpMetadata metadata_;

  // Sharding for this operator. This is structured as a "modal"-like operation,
  // in order to simplify client code, similar to metadata_.
  absl::optional<OpSharding> sharding_;

  // Mode bit that indicates whether to die when a first error is encountered.
  bool die_immediately_on_error_ = false;

  // The parent builder, or nullptr if there is none (the default).
  // NOTE(review): presumably set when this builder is created as a sub-builder
  // of another builder -- confirm.
  XlaBuilder* parent_builder_{nullptr};

  // Free functions granted friend access to XlaBuilder, so they can use its
  // non-public members.
  friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
                         const Shape& shape, const string& name);
  friend XlaOp ConstantLiteral(XlaBuilder* builder,
                               const LiteralSlice& literal);
  template <typename NativeT>
  friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
  template <typename NativeT>
  friend XlaOp ConstantR1(XlaBuilder* builder,
                          absl::Span<const NativeT> values);
  friend XlaOp ConstantR1(XlaBuilder* builder,
                          const tensorflow::core::Bitmap& values);
  template <typename NativeT>
  friend XlaOp ConstantR2(
      XlaBuilder* builder,
      std::initializer_list<std::initializer_list<NativeT>> values);
  template <typename NativeT>
  friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                           const Array<NativeT>& values,
                                           const Layout& layout);
  template <typename NativeT>
  friend XlaOp ConstantFromArray(XlaBuilder* builder,
                                 const Array<NativeT>& values);
  template <typename NativeT>
  friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                               const Array2D<NativeT>& values,
                                               const Layout& layout);
  template <typename NativeT>
  friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                                     const Array2D<NativeT>& values);
  template <typename NativeT>
  friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                               const Array3D<NativeT>& values,
                                               const Layout& layout);
  template <typename NativeT>
  friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                                     const Array3D<NativeT>& values);
  template <typename NativeT>
  friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                               const Array4D<NativeT>& values,
                                               const Layout& layout);
  template <typename NativeT>
  friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                                     const Array4D<NativeT>& values);

  template <typename NativeT>
  friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);

  friend XlaOp Broadcast(const XlaOp& operand,
                         absl::Span<const int64> broadcast_sizes);

  friend XlaOp BroadcastInDim(
      const XlaOp& operand, const Shape& shape,
      const absl::Span<const int64> broadcast_dimensions);

  friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
                   const PaddingConfig& padding_config);

  friend XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
                       absl::Span<const int64> new_sizes);

  friend XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);

  friend XlaOp Collapse(const XlaOp& operand,
                        absl::Span<const int64> dimensions);

  friend XlaOp Slice(const XlaOp& operand,
                     absl::Span<const int64> start_indices,
                     absl::Span<const int64> limit_indices,
                     absl::Span<const int64> strides);

  friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index,
                          int64 limit_index, int64 stride, int64 dimno);

  friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
                            absl::Span<const int64> slice_sizes);

  friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                                  const XlaOp& start_indices);

  friend XlaOp ConcatInDim(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands, int64 dimension);

  friend void Trace(const string& tag, const XlaOp& operand);

  friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true,
                      const XlaOp& on_false);
  friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
  friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
  friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
                   const PrecisionConfig* precision_config);
  friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                          const DotDimensionNumbers& dimension_number,
                          const PrecisionConfig* precision_config);
  friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
                    absl::Span<const int64> window_strides, Padding padding,
                    int64 feature_group_count,
                    const PrecisionConfig* precision_config);
  friend XlaOp ConvWithGeneralPadding(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64 feature_group_count, const PrecisionConfig* precision_config);
  friend XlaOp ConvWithGeneralDimensions(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides, Padding padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, const PrecisionConfig* precision_config);
  friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
                           absl::Span<const int64> window_strides,
                           absl::Span<const std::pair<int64, int64>> padding,
                           const ConvolutionDimensionNumbers& dimension_numbers,
                           int64 feature_group_count,
                           const PrecisionConfig* precision_config);
  friend XlaOp ConvGeneralDilated(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, const PrecisionConfig* precision_config);
  friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
                   absl::Span<const int64> fft_length);
  friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
                      const string& config);
  friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
                      const string& outfeed_config);
  friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
                    absl::Span<const XlaOp> operands);
  friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
                          absl::Span<const XlaOp> operands, const Shape& shape,
                          const string& opaque);
  friend XlaOp CustomCallWithLayout(
      XlaBuilder* builder, const string& call_target_name,
      absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
      absl::Span<const Shape> operand_shapes_with_layout, const string& opaque);
  friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
                       absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Conj(const XlaOp& operand);
  friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Not(const XlaOp& operand);
  friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
                         absl::Span<const int64> broadcast_dimensions);
  friend XlaOp ShiftRightArithmetic(
      const XlaOp& lhs, const XlaOp& rhs,
      absl::Span<const int64> broadcast_dimensions);
  friend XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
                                 absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions_to_reduce);
  friend XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                      absl::Span<const XlaOp> init_values,
                      const XlaComputation& computation,
                      absl::Span<const int64> dimensions_to_reduce);
  friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
                         const XlaComputation& computation);
  friend XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
                            const XlaComputation& computation,
                            absl::Span<const int64> window_dimensions,
                            absl::Span<const int64> window_strides,
                            Padding padding);
  friend XlaOp ReduceWindowWithGeneralPadding(
      const XlaOp& operand, const XlaOp& init_value,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);
  friend XlaOp CrossReplicaSum(const XlaOp& operand,
                               absl::Span<const ReplicaGroup> replica_groups);
  friend XlaOp CrossReplicaSum(const XlaOp& operand,
                               const XlaComputation& computation,
                               absl::Span<const ReplicaGroup> replica_groups,
                               const absl::optional<ChannelHandle>& channel_id);
  friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
                        int64 concat_dimension, int64 split_count,
                        const std::vector<ReplicaGroup>& replica_groups);
  friend XlaOp CollectivePermute(
      const XlaOp& operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);
  friend XlaOp SelectAndScatter(const XlaOp& operand,
                                const XlaComputation& select,
                                absl::Span<const int64> window_dimensions,
                                absl::Span<const int64> window_strides,
                                Padding padding, const XlaOp& source,
                                const XlaOp& init_value,
                                const XlaComputation& scatter);
  friend XlaOp SelectAndScatterWithGeneralPadding(
      const XlaOp& operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
      const XlaOp& init_value, const XlaComputation& scatter);
  friend XlaOp Abs(const XlaOp& operand);
  friend XlaOp Atan2(const XlaOp& y, const XlaOp& x,
                     absl::Span<const int64> broadcast_dimensions);
  friend XlaOp Exp(const XlaOp& operand);
  friend XlaOp Expm1(const XlaOp& operand);
  friend XlaOp Floor(const XlaOp& operand);
  friend XlaOp Ceil(const XlaOp& operand);
  friend XlaOp Round(const XlaOp& operand);
  friend XlaOp Log(const XlaOp& operand);
  friend XlaOp Log1p(const XlaOp& operand);
  friend XlaOp Sign(const XlaOp& operand);
  friend XlaOp Clz(const XlaOp& operand);
  friend XlaOp Cos(const XlaOp& operand);
  friend XlaOp Sin(const XlaOp& operand);
  friend XlaOp Tanh(const XlaOp& operand);
  friend XlaOp Real(const XlaOp& operand);
  friend XlaOp Imag(const XlaOp& operand);
  friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
  friend XlaOp IsFinite(const XlaOp& operand);
  friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
                    int64 iota_dimension);
  friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
  friend XlaOp ConvertElementType(const XlaOp& operand,
                                  PrimitiveType new_element_type);
  friend XlaOp BitcastConvertType(const XlaOp& operand,
                                  PrimitiveType new_element_type);
  friend XlaOp Neg(const XlaOp& operand);
  friend XlaOp Transpose(const XlaOp& operand,
                         absl::Span<const int64> permutation);
  friend XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
  friend XlaOp Sort(XlaOp keys, absl::optional<XlaOp> values, int64 dimension);
  friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
  friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                   const XlaComputation& computation,
                   absl::Span<const int64> dimensions,
                   absl::Span<const XlaOp> static_operands);
  friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma,
                         const Shape& shape);
  friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
  friend XlaOp While(const XlaComputation& condition,
                     const XlaComputation& body, const XlaOp& init);
  friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
                           const XlaComputation& true_computation,
                           const XlaOp& false_operand,
                           const XlaComputation& false_computation);
  friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
                               const int mantissa_bits);
  friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
                      const GatherDimensionNumbers& dimension_numbers,
                      absl::Span<const int64> slice_sizes);
  friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
                       const XlaOp& updates,
                       const XlaComputation& update_computation,
                       const ScatterDimensionNumbers& dimension_numbers);
  friend void Send(const XlaOp& operand, const ChannelHandle& handle);
  friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
                    const ChannelHandle& handle);
  friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
                                 const XlaOp& offset, float epsilon,
                                 int64 feature_index);
  friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
                                  const XlaOp& offset, const XlaOp& mean,
                                  const XlaOp& variance, float epsilon,
                                  int64 feature_index);
  friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
                             const XlaOp& batch_mean, const XlaOp& batch_var,
                             const XlaOp& grad_output, float epsilon,
                             int64 feature_index);
  friend XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
                             const ChannelHandle& handle);
  friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
                             const ChannelHandle& handle);
  friend XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
                          const Shape& shape_with_layout,
                          const ChannelHandle& handle);
  friend XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
                            const ChannelHandle& handle);
  friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
                               const string& config);
  friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
                                const Shape& shape_with_layout,
                                const string& outfeed_config);
  friend XlaOp CreateToken(XlaBuilder* builder);
  friend XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
};

// Scope guard (RAII) for the sharding assignment of an XlaBuilder.
//
// On construction it remembers the builder's current sharding and installs
// `sharding` in its place; an empty `sharding` clears the builder's sharding
// instead. On destruction it reinstates the remembered sharding using the same
// rule. The guard is neither copyable nor copy-assignable.
class XlaScopedShardingAssignment {
 public:
  XlaScopedShardingAssignment(xla::XlaBuilder* builder,
                              absl::optional<OpSharding> sharding)
      : builder_(builder), prev_sharding_(builder->sharding()) {
    ApplySharding(sharding);
  }

  XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
  XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
      delete;

  // Puts back whatever sharding the builder had when this guard was created.
  ~XlaScopedShardingAssignment() { ApplySharding(prev_sharding_); }

 private:
  // Installs `sharding` on the builder, or clears the builder's sharding when
  // `sharding` holds no value.
  void ApplySharding(const absl::optional<OpSharding>& sharding) {
    if (!sharding) {
      builder_->ClearSharding();
      return;
    }
    builder_->SetSharding(*sharding);
  }

  // Builder whose sharding this guard manages; never deleted by the guard.
  xla::XlaBuilder* const builder_;
  // Sharding that was active on `builder_` before this guard took effect.
  absl::optional<OpSharding> prev_sharding_;
};

// Free functions for building XlaOps. The intention is that these will
// become the public API for building XlaOps rather than calling methods on
// XlaBuilder directly.

// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
                const string& name);

// Enqueues a constant with the value of the given literal onto the
// computation.
XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);

// Enqueues a constant onto the computation. Methods are templated on the
// native host type (NativeT) which corresponds to a specific XLA
// PrimitiveType as given in the following table:
//
//  Native Type   PrimitiveType
// -----------------------------
//   bool           PRED
//   int32          S32
//   int64          S64
//   uint32         U32
//   uint64         U64
//   float          F32
//   double         F64
//
// Note: not all primitive types defined in xla_data.proto have a
// corresponding native type yet.
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values);
XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout);
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values);

// Enqueues a rank one constant (vector) onto the computation. The vector has
// size 'length' and every element has the value 'value'.
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);

// Adds dimensions to an array by duplicating the data in the array.
//
// The new dimensions are inserted on the left, i.e. if
// broadcast_sizes has values {a0, ..., aN} and the operand shape
// has dimensions {b0, ..., bM} then the shape of the output has
// dimensions {a0, ..., aN, b0, ..., bM}.
//
// The new dimensions index into copies of the operand, i.e.
//
//   output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
XlaOp Broadcast(const XlaOp& operand, absl::Span<const int64> broadcast_sizes);

// Performs in-dimension-style broadcast.
//
// Operand specifies the input to be broadcast. "shape" is the expected output
// shape. "broadcast_dimensions" are the dimensions to be broadcast into: the
// i'th entry of broadcast_dimensions names the dimension of the output shape
// that the i'th dimension of the operand is broadcast into.
// e.g.
// Say operand = [1, 2], i.e., a 1D tensor with 2 elements,
// and the dimensions of shape are [2, 2].
// Specifying {1} as broadcast_dimensions will generate output
// [1 , 2]
// [1 , 2]
// On the other hand, specifying {0} as broadcast_dimensions
// will generate output
// [1 , 1]
// [2 , 2]
XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape,
                     const absl::Span<const int64> broadcast_dimensions);

// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
// specifies the padding amount for each dimension.
XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
          const PaddingConfig& padding_config);

// Enqueues an operation onto the computation that flattens the operand based
// on the dimension order (major/slowest-varying to minor/fastest-varying)
// given, followed by reshaping it into the shape with the given dimension
// sizes (also major to minor). Conceptually, this is a limited form of
// "shape casting".
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
              absl::Span<const int64> new_sizes);

// Enqueues an operation onto the computation that collapses the operand, from
// first to last dimension (C order), then reshapes it to the given dimension
// sizes. Conceptually, this is a limited form of "shape casting".
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);

// Wrapper for Reshape.
// Enqueues an operation to collapse the provided dimensions; e.g. an
// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
// be a consecutive, in-order subsequence of the operand dimensions.
//
// Note that collapsing a single dimension does nothing:
//
//    {256} collapsing {0} => {256}
//    {1} collapsing {0} => {1}
//
// Collapsing multiple dimensions produces a single result dimension:
//
//    {256, 2} collapsing {0,1} => {512}
//    {256, 2, 3} collapsing {0,1} => {512, 3}
//
// This could potentially cause data to be moved -- it provides a more
// structured form of reshaping than an arbitrary Reshape operation.
XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);

// Enqueues a slice operation onto the computation that slices the operand
// from the start indices to the limit indices; e.g.
//
//        x
//   [ 0 1 2 3 ]
// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
//   [ 8 9 a b ]
//
// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
// range notation.
// The strides parameter determines the stride over the slice in each
// dimension. The example above uses a stride of 1 in every dimension.
XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
            absl::Span<const int64> limit_indices,
            absl::Span<const int64> strides);

// Enqueues a slice operation in a given dimension, taking all other
// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
// for:
//
//  array[:, 2:4:1, :]
XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
                 int64 stride, int64 dimno);

// Enqueues a slice operation onto the computation that slices the 'operand'
// from dynamic start indices which are passed in 'start_indices'.
// The size of the slice in each dimension is passed in 'slice_sizes',
// which specify the end point of exclusive slice intervals in each
// dimension [start, start + size).
// The shape of 'start_indices' must be rank == 1, with dimension size
// equal to the rank of the 'operand'.
// Slice index calculations are computed modulo input dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
                   absl::Span<const int64> slice_sizes);

// Enqueues a dynamic update slice operation onto the computation, which
// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
// The shape of 'update' determines the shape of the slice of 'operand'
// which is updated.
// The indices specified in 'start_indices' specify the offset of the slice
// of 'operand' which is updated.
//
//               update = {10, 11} // calculated at runtime.
//   [1 2 3]     start  = {1, 1}   // calculated at runtime.  [1 2  3 ]
//   [4 5 6]  => DynamicUpdateSlice(data, update, start)   => [4 10 11]
//   [7 8 9]                                                  [7 8  9 ]
//
// In this example 'data' is the 3x3 operand and 'update' is a 1x2 array.
//
// The shape of 'start_indices' must be rank == 1, with dimension size
// equal to the rank of the 'operand'.
// Slice index calculations are computed modulo update dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                         const XlaOp& start_indices);

// Enqueues a concatenate instruction onto the computation. 'operands' must
// have >= 1 entry.
XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                  int64 dimension);

// Enqueues a tracing operation onto the computation; the computation will emit
// a logging message with the operand.
void Trace(const string& tag, const XlaOp& operand);

// Enqueues a conditional-move-like select operation onto the computation;
// predicated on pred, selects between on_true and on_false.
XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);

// Enqueues a tuple-creation instruction onto the computation.
XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);

// Enqueues a tuple-element-get instruction onto the computation.
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);

// Enqueues an equal-to comparison instruction onto the computation.
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a not-equal comparison instruction onto the computation.
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a greater-or-equal comparison instruction onto the computation.
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a greater-than comparison instruction onto the computation.
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a less-than comparison instruction onto the computation.
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a less-or-equal comparison instruction onto the computation.
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a dot instruction onto the computation.
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
          const PrecisionConfig* precision_config = nullptr);

// Enqueues a general dot instruction onto the computation.
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                 const DotDimensionNumbers& dimension_numbers,
                 const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
           absl::Span<const int64> window_strides, Padding padding,
           int64 feature_group_count = 1,
           const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs,
                             absl::Span<const int64> window_strides,
                             absl::Span<const std::pair<int64, int64>> padding,
                             int64 feature_group_count = 1,
                             const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count = 1,
    const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> window_strides,
                  absl::Span<const std::pair<int64, int64>> padding,
                  const ConvolutionDimensionNumbers& dimension_numbers,
                  int64 feature_group_count = 1,
                  const PrecisionConfig* precision_config = nullptr);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
                         absl::Span<const int64> window_strides,
                         absl::Span<const std::pair<int64, int64>> padding,
                         absl::Span<const int64> lhs_dilation,
                         absl::Span<const int64> rhs_dilation,
                         const ConvolutionDimensionNumbers& dimension_numbers,
                         int64 feature_group_count = 1,
                         const PrecisionConfig* precision_config = nullptr);

// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
XlaOp Fft(const XlaOp& operand, FftType fft_type,
          absl::Span<const int64> fft_length);

// Enqueues an infeed instruction onto the computation, which reads data of
// the given shape from the infeed buffer of the device.
XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
             const string& config = "");

// Variant of Infeed which takes a token-shaped operand and produces a
// two-element tuple containing the data value and a token-shaped value.
// Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
                      const string& config = "");

// Enqueues an outfeed instruction onto the computation. This instruction
// generates outgoing data transfers for the given data.
//
// shape_with_layout communicates the laid out shape that we want to outfeed
// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
// will occur.
void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
             const string& outfeed_config);

// Variant of Outfeed which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
                       const Shape& shape_with_layout,
                       const string& outfeed_config);

// Enqueues a call instruction onto the computation.
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
           absl::Span<const XlaOp> operands);

// Enqueues a custom call instruction onto the computation. A custom call
// invokes code external to XLA. The |operands| are passed to the external code,
// and the external code is expected to produce a result of the given
// |shape|. The exact mechanism is backend-specific. For example, in the CPU
// backend, a call instruction is emitted which targets a symbol with the name
// |call_target_name|.  |call_target_name| and |opaque| can be arbitrary
// strings, but |call_target_name| should be short as it may be used in labels.
// |opaque| can encode arbitrarily large amounts of information.
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
                 absl::Span<const XlaOp> operands, const Shape& shape,
                 const string& opaque = "");

// Variant of CustomCall which constructs a custom call with fixed layouts. The
// operands will have the layouts specified by |operand_shapes_with_layout| when
// provided to external code, and the external code is expected to produce a
// result with the layout specified by |shape_with_layout|. |shape_with_layout|
// and all shapes in |operand_shapes_with_layout| must have layouts.
XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
                           absl::Span<const XlaOp> operands,
                           const Shape& shape_with_layout,
                           absl::Span<const Shape> operand_shapes_with_layout,
                           const string& opaque = "");

// The following functions enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
// of the operands is a scalar, or an explicit broadcast dimension is given
// (see g3doc for more details).

// Enqueues a complex compose instruction onto the computation.
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
              absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a complex conjugate instruction onto the computation.
XlaOp Conj(const XlaOp& operand);

// Enqueues an add instruction onto the computation.
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a subtract instruction onto the computation.
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a multiply instruction onto the computation.
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a divide instruction onto the computation.
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a remainder instruction onto the computation.
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a max instruction onto the computation.
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a min instruction onto the computation.
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Element-wise logical operators.
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions = {});

XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

XlaOp Not(const XlaOp& operand);

// Element-wise shift operators. Like the other binary operators, they accept
// optional broadcast_dimensions.
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
                absl::Span<const int64> broadcast_dimensions = {});
XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
                           absl::Span<const int64> broadcast_dimensions = {});
XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
                        absl::Span<const int64> broadcast_dimensions = {});

// Reduces an array along the provided dimensions, given "computation" as a
// reduction operator.
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce);

// Reduces several arrays simultaneously along the provided dimensions, given
// "computation" as a reduction operator. NOTE(review): `operands` and
// `init_values` are presumably parallel lists (init_values[i] seeds the
// reduction of operands[i]); confirm.
XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
             absl::Span<const XlaOp> init_values,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce);

// Convenience form of the single-operand Reduce() that reduces all the
// dimensions in the operand shape.
XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
                const XlaComputation& computation);

// Enqueues a windowed reduce instruction onto the computation. "computation"
// is the reduction operator; the windows are described by
// `window_dimensions`, `window_strides` and `padding`.
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding);

// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding(), and base/window dilations may also be given.
// NOTE(review): each `padding` pair is presumably the (low, high) padding of
// one dimension; confirm against MakePadding().
XlaOp ReduceWindowWithGeneralPadding(
    const XlaOp& operand, const XlaOp& init_value,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding);

// Returns the sum of the operand value within each subgroup of replicas. All
// replicas supply one input to the sum and all replicas receive the resulting
// sum for each subgroup.
XlaOp CrossReplicaSum(const XlaOp& operand,
                      absl::Span<const ReplicaGroup> replica_groups = {});

// Enqueues an operation that performs an AllReduce of the operand across
// cores. Here AllReduce means doing a reduction on the input operand across
// cores and then broadcasting the reduction result to those cores. The
// reduction function is defined by `computation`, which should be a
// commutative computation on scalars, e.g., add, min, or max. The way that
// AllReduce is applied is configured by:
//
// - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
// empty, all replicas belong to one group. AllReduce will be applied within
// subgroups. For example, with 4 replicas, replica_groups={{0,2},{1,3}}
// means replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in
// subgroup 1.
//
// - `channel_id`: for AllReduce nodes from different modules, if they have the
// same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be
// applied across modules.
//
// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
XlaOp CrossReplicaSum(
    const XlaOp& operand, const XlaComputation& computation,
    absl::Span<const ReplicaGroup> replica_groups = {},
    const absl::optional<ChannelHandle>& channel_id = absl::nullopt);

// Enqueues an operation that performs an AllToAll of the operand across
// cores. NOTE(review): the operand is presumably split into `split_count`
// pieces along `split_dimension`, exchanged among the replicas of each
// group in `replica_groups`, and the received pieces concatenated along
// `concat_dimension`; confirm.
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
               int64 concat_dimension, int64 split_count,
               const std::vector<ReplicaGroup>& replica_groups = {});

// Enqueues a collective operation that sends and receives data across
// replicas.
//
// - `source_target_pairs`: a list of (source_replica_id, target_replica_id)
// pairs. For each pair, the operand is sent from the source replica to the
// target replica. Note that 1) no two pairs may have the same target replica
// id, and no two pairs may have the same source replica id; 2) if a replica
// id is not a target in any pair, then the output on that replica is a tensor
// of zeros with the same shape as the input.
XlaOp CollectivePermute(
    const XlaOp& operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs);

// Enqueues an operation that scatters the `source` array to the selected
// indices of each window. NOTE(review): `select` presumably chooses one
// element per window of `operand`, and `scatter` combines `source` values
// into an output initialized with `init_value`; confirm.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
                       absl::Span<const int64> window_dimensions,
                       absl::Span<const int64> window_strides, Padding padding,
                       const XlaOp& source, const XlaOp& init_value,
                       const XlaComputation& scatter);

// As SelectAndScatter(), but the padding is given in the format
// returned by MakePadding().
XlaOp SelectAndScatterWithGeneralPadding(
    const XlaOp& operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
    const XlaOp& init_value, const XlaComputation& scatter);

// Enqueues an abs instruction onto the computation.
XlaOp Abs(const XlaOp& operand);

// Enqueues an atan2(y, x) instruction onto the computation.
XlaOp Atan2(const XlaOp& y, const XlaOp& x,
            absl::Span<const int64> broadcast_dimensions = {});

// Enqueues an exp instruction onto the computation.
XlaOp Exp(const XlaOp& operand);

// Enqueues an expm1 instruction (exp(x)-1) onto the computation.
XlaOp Expm1(const XlaOp& operand);

// Enqueues a floor instruction onto the computation.
XlaOp Floor(const XlaOp& operand);

// Enqueues a ceil instruction onto the computation.
XlaOp Ceil(const XlaOp& operand);

// Enqueues a round instruction onto the computation, rounding to the nearest
// integer with half-way cases rounding away from zero.
XlaOp Round(const XlaOp& operand);

// Enqueues a log instruction (natural logarithm) onto the computation.
XlaOp Log(const XlaOp& operand);

// Enqueues a log1p instruction (log(x+1)) onto the computation.
XlaOp Log1p(const XlaOp& operand);

// Enqueues a sign instruction onto the computation.
XlaOp Sign(const XlaOp& operand);

// Enqueues a count leading zeros instruction onto the computation.
XlaOp Clz(const XlaOp& operand);

// Enqueues a cosine instruction onto the computation.
XlaOp Cos(const XlaOp& operand);

// Enqueues a sine instruction onto the computation.
XlaOp Sin(const XlaOp& operand);

// Enqueues a tanh instruction onto the computation.
XlaOp Tanh(const XlaOp& operand);

// Enqueues a real-part instruction onto the computation.
XlaOp Real(const XlaOp& operand);

// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(const XlaOp& operand);

// Enqueues an element-wise lhs^rhs (power) instruction onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues an operator that tests if the operand's values are finite, i.e.,
// not Inf or NaN. Defined only for floating-point types. Returns an array of
// booleans with the same shape where entries are true iff the corresponding
// entry was finite.
XlaOp IsFinite(const XlaOp& operand);

// Enqueues an iota operation of shape `shape` onto the computation;
// `iota_dimension` selects the dimension along which the iota sequence runs.
XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension);

// Enqueues a rank-1 iota operation with `size` elements of type `type` onto
// the computation.
XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);

// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to `new_element_type`.
XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type);

// Enqueues a no-op instruction onto the computation that changes
// the element type of the operand array to `new_element_type`. The
// bit-widths of the source and destination element types must be
// identical.
XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type);

// Enqueues a negate instruction onto the computation.
XlaOp Neg(const XlaOp& operand);

// Enqueues a transpose instruction onto the computation.
XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);

// Enqueues a reverse instruction onto the computation. The order of the
// elements in the given dimensions is reversed (i.e., the element at index i
// is moved to index dimension_size - 1 - i).
XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);

// Enqueues a sort (in increasing order) instruction onto the computation.
// If only keys are provided:
// * If the keys are a rank-1 tensor (an array), the result is a sorted array
// of keys, in ascending order.
// * If the keys have higher rank, the keys are sorted along the provided
// dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
// value of 0 will independently sort every column, and a dimension value of 1
// will independently sort each row. If no dimension number is provided (the
// default `dimension = -1`), then the last dimension is chosen.
//
// If both keys and values are provided:
// * The keys and the values must be tensors with the same dimensions. The
// element types of the tensors may be different.
// * The result is a tuple that consists of a sorted tensor of keys (along the
// provided dimension, as above) as the first element, and a tensor with their
// corresponding values as the second element.
XlaOp Sort(XlaOp keys, absl::optional<XlaOp> values = absl::nullopt,
           int64 dimension = -1);

// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);

// Enqueues a map instruction onto the computation, applying `computation`
// to `operands` (plus any `static_operands`) over `dimensions`.
XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation, absl::Span<const int64> dimensions,
          absl::Span<const XlaOp> static_operands = {});

// Enqueues an N(mu, sigma) random number generation instruction onto the
// computation. NOTE(review): `sigma` is presumably the standard deviation
// (not the variance); confirm.
XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);

// Enqueues a U(a, b) random number generation instruction onto the
// computation. Returns values in the semi-open interval [a, b).
XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);

// Enqueues a while node onto the computation. `init` is the initial loop
// value; NOTE(review): `condition` and `body` presumably both take the loop
// value as their parameter; confirm.
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            const XlaOp& init);

// Enqueues a conditional node onto the computation. Depending on
// `predicate`, either `true_computation` is applied to `true_operand` or
// `false_computation` is applied to `false_operand`.
XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
                  const XlaComputation& true_computation,
                  const XlaOp& false_operand,
                  const XlaComputation& false_computation);

// Enqueues a ReducePrecision node onto the computation, which reduces the
// operand's floating-point precision to `exponent_bits` exponent bits and
// `mantissa_bits` mantissa bits.
XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
                      const int mantissa_bits);

// Enqueues a Gather node onto the computation.
XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64> slice_sizes);

// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
              const XlaOp& updates, const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers);

// Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to
// a Recv instruction in a different computation that shares the same channel
// handle.
void Send(const XlaOp& operand, const ChannelHandle& handle);

// Variant of Send which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
                    const ChannelHandle& handle);

// Enqueues a Recv node onto the computation for device-to-device
// communication. The data comes from a Send instruction in a different
// computation that shares the same channel handle and its shape must be the
// same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle);

// Variant of Recv which takes a token-shaped operand and produces a two-element
// tuple containing the data value and a token-shaped value. Tokens are used
// for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
                    const ChannelHandle& handle);

// Enqueues a Send node which transfers data from the device to the host. The
// 'shape_with_layout' argument defines the layout of the data transferred; its
// shape must be compatible with the shape of the operand. The operand must be
// array-shaped.
// TODO(b/111544877): Support tuple shapes.
XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
                 const Shape& shape_with_layout, const ChannelHandle& handle);

// Enqueues a Recv node which transfers data from the host to the device. The
// given shape must contain a layout and must be an array.
// TODO(b/111544877): Support tuple shapes.
XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
                   const ChannelHandle& handle);

// Enqueues an operation (AfterAll) with no operands that produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// This is a separate method from AfterAll to facilitate the removal of
// operand-less AfterAll instructions.
// TODO(b/110532604): Remove this function when all tokens are derived from a
// single token generated or passed into the entry computation.
XlaOp CreateToken(XlaBuilder* builder);

// Enqueues an AfterAll instruction which produces a token-shaped value and
// takes a variadic number of token-shaped operands. The number of operands must
// be greater than zero. Used for joining tokens.
XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);

// Normalizes operand across spatial and batch dimensions for each feature.
// `feature_index` selects the feature dimension of `operand`, and `epsilon`
// is the small value used by the normalization (presumably added to the
// variance for numerical stability; confirm).
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
// is the normalized result and batch_mean and batch_var are the mean and
// variance, respectively, across batch for the operand.
XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
                        const XlaOp& offset, float epsilon,
                        int64 feature_index);

// Normalizes operand across spatial and batch dimensions for each feature.
//
// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
// computing `mean` and `variance` for each batch inside the operation. It
// uses the input `mean` and `variance` instead as estimated values. The
// purpose of this op is to reduce latency in inference, hence the name
// `BatchNormInference`.
//
// The output has the same shape as `operand`, and contains the normalized
// values for each batch.
XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
                         const XlaOp& offset, const XlaOp& mean,
                         const XlaOp& variance, float epsilon,
                         int64 feature_index);

// Calculates the gradients of a batch norm op.
//
// The inputs `batch_mean` and `batch_var` represent the mean and variance
// across the batch.
//
// Returns a tuple of three elements:
//   - grad_operand: Gradient with respect to input `operand`
//   - grad_offset: Gradient with respect to input `offset`
//   - grad_scale: Gradient with respect to input `scale`
// NOTE(review): the XLA operation-semantics documentation lists the tuple as
// (grad_operand, grad_scale, grad_offset); confirm the order of the last two
// elements above against the implementation.
XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
                    const XlaOp& batch_mean, const XlaOp& batch_var,
                    const XlaOp& grad_output, float epsilon,
                    int64 feature_index);

// Implementation details below this point.

template <typename NativeT>
XlaOp XlaBuilder::ConstantR0(NativeT value) {
  // Materialize the scalar as a rank-0 literal, then emit it as a constant.
  const Literal scalar_literal = LiteralUtil::CreateR0<NativeT>(value);
  return ConstantLiteral(scalar_literal);
}

template <typename NativeT>
XlaOp XlaBuilder::ConstantR1(absl::Span<const NativeT> values) {
  // Copy `values` into a rank-1 literal and emit it as a constant.
  const Literal vector_literal = LiteralUtil::CreateR1<NativeT>(values);
  return ConstantLiteral(vector_literal);
}

template <typename NativeT>
XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
  // Build a rank-1 literal of `length` elements whose element type matches
  // NativeT, fill every element with `value`, and emit it as a constant.
  const PrimitiveType element_type =
      primitive_util::NativeToPrimitiveType<NativeT>();
  Literal filled(ShapeUtil::MakeShape(element_type, {length}));
  filled.PopulateWithValue(value);
  return ConstantLiteral(filled);
}

inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
  // Convert the bitmap into a rank-1 literal and emit it as a constant.
  const Literal bitmap_literal = LiteralUtil::CreateR1(values);
  return ConstantLiteral(bitmap_literal);
}

template <typename NativeT>
XlaOp XlaBuilder::ConstantR2(
    std::initializer_list<std::initializer_list<NativeT>> values) {
  // Build a rank-2 literal from the nested lists and emit it as a constant.
  const Literal matrix_literal = LiteralUtil::CreateR2<NativeT>(values);
  return ConstantLiteral(matrix_literal);
}

template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
                                              const Layout& layout) {
  // Build a literal from the array using the caller-supplied layout.
  const Literal laid_out =
      LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout);
  return ConstantLiteral(laid_out);
}

template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
  // Build a literal from the array (default layout) and emit it.
  const Literal array_literal = LiteralUtil::CreateFromArray<NativeT>(values);
  return ConstantLiteral(array_literal);
}

template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
    const Array2D<NativeT>& values, const Layout& layout) {
  // Array2D is an Array, so the generic array-with-layout path applies.
  const Array<NativeT>& as_array = values;
  return ConstantLiteral(
      LiteralUtil::CreateFromArrayWithLayout<NativeT>(as_array, layout));
}

template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
  // Convert the 2-D array into a rank-2 literal and emit it as a constant.
  const Literal matrix_literal =
      LiteralUtil::CreateR2FromArray2D<NativeT>(values);
  return ConstantLiteral(matrix_literal);
}

template <typename NativeT>
XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
    const Array3D<NativeT>& values, const Layout& layout) {
  // Convert the 3-D array into a rank-3 literal with the requested layout.
  const Literal cube_literal =
      LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout);
  return ConstantLiteral(cube_literal);
}

template <typename NativeT>
XlaOp XlaBuilder::ConstantR3FromArray3D(const Array3D<NativeT>& values) {
  // Array3D is an Array; forward to the generic array overload.
  const Array<NativeT>& as_array = values;
  return ConstantFromArray(as_array);
}

template <typename NativeT>
XlaOp XlaBuilder::ConstantR4FromArray4DWithLayout(
    const Array4D<NativeT>& values, const Layout& layout) {
  // Array4D is an Array; forward to the generic array-with-layout overload.
  const Array<NativeT>& as_array = values;
  return ConstantFromArrayWithLayout(as_array, layout);
}

template <typename NativeT>
XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
  // Array4D is an Array; forward to the generic array overload.
  const Array<NativeT>& as_array = values;
  return ConstantFromArray(as_array);
}

// Free function template implementations.

template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
  // Materialize the scalar as a rank-0 literal and emit it via `builder`.
  const Literal scalar_literal = LiteralUtil::CreateR0<NativeT>(value);
  return ConstantLiteral(builder, scalar_literal);
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
  // Copy `values` into a rank-1 literal and emit it via `builder`.
  const Literal vector_literal = LiteralUtil::CreateR1<NativeT>(values);
  return ConstantLiteral(builder, vector_literal);
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
  // Build a rank-1 literal of `length` elements whose element type matches
  // NativeT, fill every element with `value`, and emit it via `builder`.
  const PrimitiveType element_type =
      primitive_util::NativeToPrimitiveType<NativeT>();
  Literal filled(ShapeUtil::MakeShape(element_type, {length}));
  filled.PopulateWithValue(value);
  return ConstantLiteral(builder, filled);
}

inline XlaOp ConstantR1(XlaBuilder* builder,
                        const tensorflow::core::Bitmap& values) {
  // Convert the bitmap into a rank-1 literal and emit it via `builder`.
  const Literal bitmap_literal = LiteralUtil::CreateR1(values);
  return ConstantLiteral(builder, bitmap_literal);
}

template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values) {
  // Build a rank-2 literal from the nested lists and emit it via `builder`.
  const Literal matrix_literal = LiteralUtil::CreateR2<NativeT>(values);
  return ConstantLiteral(builder, matrix_literal);
}

template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout) {
  // Build a literal from the array using the caller-supplied layout.
  const Literal laid_out =
      LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout);
  return ConstantLiteral(builder, laid_out);
}

template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
  // Build a literal from the array (default layout) and emit it.
  const Literal array_literal = LiteralUtil::CreateFromArray<NativeT>(values);
  return ConstantLiteral(builder, array_literal);
}

template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout) {
  // Array2D is an Array, so the generic array-with-layout path applies.
  const Array<NativeT>& as_array = values;
  return ConstantLiteral(
      builder,
      LiteralUtil::CreateFromArrayWithLayout<NativeT>(as_array, layout));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values) {
  // Convert the 2-D array into a rank-2 literal and emit it via `builder`.
  const Literal matrix_literal =
      LiteralUtil::CreateR2FromArray2D<NativeT>(values);
  return ConstantLiteral(builder, matrix_literal);
}

template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout) {
  // Convert the 3-D array into a rank-3 literal with the requested layout.
  const Literal cube_literal =
      LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout);
  return ConstantLiteral(builder, cube_literal);
}

template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values) {
  // Array3D is an Array; forward to the generic array overload.
  const Array<NativeT>& as_array = values;
  return ConstantFromArray(builder, as_array);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout) {
  // Array4D is an Array; forward to the generic array-with-layout overload.
  const Array<NativeT>& as_array = values;
  return ConstantFromArrayWithLayout(builder, as_array, layout);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values) {
  // Array4D is an Array; forward to the generic array overload.
  const Array<NativeT>& as_array = values;
  return ConstantFromArray(builder, as_array);
}

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_