aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
blob: dcd726f22c71b4bd709dc63b25d6fdea477c83c7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/shape_inference.h"

#include <stddef.h>
#include <algorithm>
#include <numeric>
#include <set>
#include <string>

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {

namespace {

// Return the UnaryOperation proto enum value associated with the given HLO
// opcode.
// Maps an HLO opcode onto the equivalent UnaryOperation proto enum value.
// Only opcodes with a proto-level unary counterpart are accepted; any other
// opcode is treated as a programming error and terminates via LOG(FATAL).
UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
  switch (opcode) {
    // Rounding.
    case HloOpcode::kCeil:
      return UNOP_CEIL;
    case HloOpcode::kFloor:
      return UNOP_FLOOR;
    case HloOpcode::kRoundNearestAfz:
      return UNOP_ROUND_NEAREST_AFZ;
    // Transcendental functions.
    case HloOpcode::kCos:
      return UNOP_COS;
    case HloOpcode::kExp:
      return UNOP_EXP;
    case HloOpcode::kLog:
      return UNOP_LOG;
    case HloOpcode::kSin:
      return UNOP_SIN;
    case HloOpcode::kTanh:
      return UNOP_TANH;
    // Complex component extraction.
    case HloOpcode::kReal:
      return UNOP_REAL;
    case HloOpcode::kImag:
      return UNOP_IMAG;
    // Sign manipulation.
    case HloOpcode::kAbs:
      return UNOP_ABS;
    case HloOpcode::kNegate:
      return UNOP_NEGATE;
    case HloOpcode::kSign:
      return UNOP_SIGN;
    // Miscellaneous.
    case HloOpcode::kIsFinite:
      return UNOP_IS_FINITE;
    case HloOpcode::kNot:
      return UNOP_NOT;
    case HloOpcode::kSort:
      return UNOP_SORT;
    default:
      LOG(FATAL) << "Unhandled opcode for conversion to unary operation: "
                 << opcode;
  }
}

// Return the BinaryOperation proto enum value associated with the given HLO
// opcode.
// Maps an HLO opcode onto the equivalent BinaryOperation proto enum value.
// Opcodes without a proto-level binary counterpart terminate via LOG(FATAL).
BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) {
  switch (opcode) {
    // Element-wise arithmetic.
    case HloOpcode::kAdd:
      return BINOP_ADD;
    case HloOpcode::kSubtract:
      return BINOP_SUB;
    case HloOpcode::kMultiply:
      return BINOP_MUL;
    case HloOpcode::kDivide:
      return BINOP_DIV;
    case HloOpcode::kRemainder:
      return BINOP_REM;
    case HloOpcode::kPower:
      return BINOP_POW;
    case HloOpcode::kMaximum:
      return BINOP_MAX;
    case HloOpcode::kMinimum:
      return BINOP_MIN;
    case HloOpcode::kAtan2:
      return BINOP_ATAN2;
    // Comparisons.
    case HloOpcode::kEq:
      return BINOP_EQ;
    case HloOpcode::kNe:
      return BINOP_NE;
    case HloOpcode::kLt:
      return BINOP_LT;
    case HloOpcode::kLe:
      return BINOP_LE;
    case HloOpcode::kGt:
      return BINOP_GT;
    case HloOpcode::kGe:
      return BINOP_GE;
    // Bitwise / logical.
    case HloOpcode::kAnd:
      return BINOP_AND;
    case HloOpcode::kOr:
      return BINOP_OR;
    // Shifts.
    case HloOpcode::kShiftLeft:
      return BINOP_SHIFT_LEFT;
    case HloOpcode::kShiftRightArithmetic:
      return BINOP_SHIFT_RIGHT_ARITHMETIC;
    case HloOpcode::kShiftRightLogical:
      return BINOP_SHIFT_RIGHT_LOGICAL;
    // Other binary operations.
    case HloOpcode::kComplex:
      return BINOP_COMPLEX;
    case HloOpcode::kDot:
      return BINOP_DOT;
    default:
      LOG(FATAL) << "unhandled opcode " << opcode;
  }
}

// Return the TernaryOperation proto enum value associated with the given HLO
// opcode.
// Maps an HLO opcode onto the equivalent TernaryOperation proto enum value.
// Only kClamp and kSelect are ternary; anything else terminates via
// LOG(FATAL).
TernaryOperation OpcodeToTernaryOperation(HloOpcode opcode) {
  if (opcode == HloOpcode::kClamp) {
    return TRIOP_CLAMP;
  }
  if (opcode == HloOpcode::kSelect) {
    return TRIOP_SELECT;
  }
  LOG(FATAL) << "unhandled opcode " << opcode;
}

// Return the VariadicOperation proto enum value associated with the given HLO
// opcode.
// Maps an HLO opcode onto the equivalent VariadicOperation proto enum value.
// kTuple is the only variadic opcode handled; anything else terminates via
// LOG(FATAL).
VariadicOperation OpcodeToVariadicOperation(HloOpcode opcode) {
  if (opcode == HloOpcode::kTuple) {
    return VAROP_TUPLE;
  }
  LOG(FATAL) << "unhandled opcode " << opcode;
}

// Returns true if no element is present in slice more than once.
// Returns true iff every value in `slice` occurs exactly once; an empty slice
// is trivially unique.
bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
  std::set<int64> seen;
  for (const int64 value : slice) {
    // insert().second is false when the value was already present.
    if (!seen.insert(value).second) {
      return false;
    }
  }
  return true;
}

// Returns InvalidArgument if `shape` is a tuple or an opaque shape, naming
// `op_type` (a description of the operand's role) in the message; returns OK
// for any other shape.
tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape,
                                          tensorflow::StringPiece op_type) {
  if (ShapeUtil::IsTuple(shape)) {
    return InvalidArgument("Expected non-tuple argument for %s. Got: %s",
                           op_type.ToString().c_str(),
                           ShapeUtil::HumanString(shape).c_str());
  }
  if (ShapeUtil::IsOpaque(shape)) {
    return InvalidArgument("Expected non-opaque argument for %s. Got: %s",
                           op_type.ToString().c_str(),
                           ShapeUtil::HumanString(shape).c_str());
  }
  return tensorflow::Status::OK();
}

// Validates that `reducer_shape` describes a reduction function usable to
// fold elements of type `input_element_type` into an accumulator seeded with
// a value of shape `init_value_shape`. Requirements, checked in this order:
//   1. exactly two parameters;
//   2. a rank-0 result (otherwise Unimplemented);
//   3. parameter 0 compatible with the result;
//   4. init_value compatible with the result;
//   5. parameter 1 compatible with a scalar of `input_element_type`;
//   6. parameter 1 compatible with the result.
// Returns OK or the status for the first violated requirement.
tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
                                      const Shape& init_value_shape,
                                      const PrimitiveType& input_element_type) {
  if (reducer_shape.parameters_size() != 2) {
    return InvalidArgument(
        "Reduction function must take 2 parameters, but "
        "takes %d parameter(s).",
        reducer_shape.parameters_size());
  }

  const Shape& result_shape = reducer_shape.result();
  if (ShapeUtil::Rank(result_shape) != 0) {
    return Unimplemented(
        "Reduction function currently must have rank-0 result.");
  }

  // Both parameters exist: the arity was verified above.
  const Shape& first_param = reducer_shape.parameters(0);
  const Shape& second_param = reducer_shape.parameters(1);

  // All comparisons use Compatible (not Equal) because layout is irrelevant
  // for scalars; see b/26668201 for a longer-term vision.
  if (!ShapeUtil::Compatible(result_shape, first_param)) {
    return InvalidArgument(
        "Reduction function's first parameter shape differs from the "
        "result shape: %s vs %s",
        ShapeUtil::HumanString(first_param).c_str(),
        ShapeUtil::HumanString(result_shape).c_str());
  }

  if (!ShapeUtil::Compatible(result_shape, init_value_shape)) {
    return InvalidArgument(
        "Reduction function's accumulator shape differs from the "
        "init_value shape: %s vs %s",
        ShapeUtil::HumanString(result_shape).c_str(),
        ShapeUtil::HumanString(init_value_shape).c_str());
  }

  // Each input element is fed to the reducer as a scalar second argument.
  const Shape scalar_input_shape =
      ShapeUtil::MakeShape(input_element_type, {});
  if (!ShapeUtil::Compatible(scalar_input_shape, second_param)) {
    return InvalidArgument(
        "Reduction function's second parameter shape differs from the "
        "input type element type: %s vs %s",
        ShapeUtil::HumanString(second_param).c_str(),
        ShapeUtil::HumanString(scalar_input_shape).c_str());
  }

  // Accumulator and input element currently must share a type; this
  // restriction could be relaxed in the future.
  if (!ShapeUtil::Compatible(result_shape, second_param)) {
    return InvalidArgument(
        "Reduction function's second parameter shape currently must "
        "match the result shape. Got %s vs %s",
        ShapeUtil::HumanString(second_param).c_str(),
        ShapeUtil::HumanString(result_shape).c_str());
  }

  return tensorflow::Status::OK();
}

// Computes the shape produced by sliding `window` over an array of shape
// `base_shape`, with element type `element_type`.
//
// Each window dimension must have positive size and stride, and base/window
// dilation factors of at least 1. Negative low/high padding is rejected
// unless `allow_negative_padding` is true. Per dimension, the output extent is
// window_util::StridedBound of (low padding + dilated base + high padding),
// the dilated window size and the stride.
StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
                                       const Window& window,
                                       PrimitiveType element_type,
                                       bool allow_negative_padding) {
  if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) {
    return InvalidArgument(
        "Window has dimension %d but base shape has dimension %lld.",
        window.dimensions_size(), ShapeUtil::Rank(base_shape));
  }

  std::vector<int64> out_dims;
  out_dims.reserve(window.dimensions_size());
  for (int64 d = 0; d < window.dimensions_size(); ++d) {
    const auto& wd = window.dimensions(d);

    if (wd.size() <= 0) {
      return InvalidArgument("Window has a non-positive dimension. Window: %s",
                             window.DebugString().c_str());
    }
    if (wd.stride() <= 0) {
      return InvalidArgument("Window has a non-positive stride. Window: %s",
                             window.DebugString().c_str());
    }
    if (!allow_negative_padding) {
      if (wd.padding_low() < 0) {
        return InvalidArgument("Window has a negative low padding. Window: %s",
                               window.DebugString().c_str());
      }
      if (wd.padding_high() < 0) {
        return InvalidArgument(
            "Window has a negative high padding. Window: %s",
            window.DebugString().c_str());
      }
    }
    if (wd.base_dilation() < 1) {
      return InvalidArgument(
          "Window has a non-positive base area dilation factor. Window: %s",
          window.DebugString().c_str());
    }
    if (wd.window_dilation() < 1) {
      return InvalidArgument(
          "Window has a non-positive window dilation factor. Window: %s",
          window.DebugString().c_str());
    }

    const int64 base_extent = wd.padding_low() +
                              window_util::DilatedBound(
                                  ShapeUtil::GetDimension(base_shape, d),
                                  wd.base_dilation()) +
                              wd.padding_high();
    const int64 window_extent =
        window_util::DilatedBound(wd.size(), wd.window_dilation());
    out_dims.push_back(
        window_util::StridedBound(base_extent, window_extent, wd.stride()));
  }

  return ShapeUtil::MakeShape(element_type, out_dims);
}

}  // namespace

/* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
    HloOpcode opcode, const HloInstruction* operand) {
  // Infers the result shape of a unary HLO instruction applied to `operand`.
  const Shape& operand_shape = operand->shape();
  if (opcode != HloOpcode::kCopy) {
    return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), operand_shape);
  }
  // kCopy has no proto-level UnaryOperation equivalent; a copy simply
  // preserves its operand's shape.
  return operand_shape;
}

/* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
    UnaryOperation operation, const Shape& arg) {
  // Infers the result shape of unary `operation` applied to an array `arg`.
  // Tuple and opaque operands are rejected. Most operations preserve `arg`;
  // real/imag/abs of complex values yield the complex component type, and
  // is-finite yields PRED. Type restrictions per operation are enforced below
  // and reported as InvalidArgument.
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation"));

  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(arg));
  switch (operation) {
    case UNOP_FLOOR:
    case UNOP_CEIL:
      if (!ShapeUtil::ElementIsFloating(arg)) {
        return InvalidArgument(
            "expected element type in shape to be floating for floor/ceil "
            "operation; got %s",
            PrimitiveType_Name(arg.element_type()).c_str());
      }
      return arg;
    case UNOP_COS:
    case UNOP_SIN:
    case UNOP_EXP:
    case UNOP_LOG:
    case UNOP_TANH:
      if (!ShapeUtil::ElementIsFloating(arg) &&
          !ShapeUtil::ElementIsComplex(arg)) {
        return InvalidArgument(
            "expected element type in shape to be floating or complex for "
            "sin/cos/exp/log/tanh operation; got %s",
            PrimitiveType_Name(arg.element_type()).c_str());
      }
      return arg;
    case UNOP_REAL:
    case UNOP_IMAG:
      if (!ShapeUtil::ElementIsComplex(arg)) {
        return InvalidArgument(
            "expected element type in shape to be complex for real/imag "
            "operation; got %s",
            PrimitiveType_Name(arg.element_type()).c_str());
      }
      // Each result element is one component of a complex input element, so
      // the result type is the input's complex component type. This matches
      // UNOP_ABS below, instead of hard-coding F32, which is only correct for
      // a single complex width.
      return ShapeUtil::ChangeElementType(
          arg, primitive_util::ComplexComponentType(arg.element_type()));
    case UNOP_ABS:
      if (ShapeUtil::ElementIsComplex(arg)) {
        return ShapeUtil::ChangeElementType(
            arg, primitive_util::ComplexComponentType(arg.element_type()));
      }
      return arg;
    case UNOP_NEGATE:
    case UNOP_ROUND_NEAREST_AFZ:
    case UNOP_SIGN:
    case UNOP_SORT:
      return arg;

    case UNOP_NOT:
      if (arg.element_type() != PRED &&
          !primitive_util::IsIntegralType(arg.element_type())) {
        return InvalidArgument(
            "expected pred or an integral element type in argument to not "
            "operation; got %s",
            PrimitiveType_Name(arg.element_type()).c_str());
      }
      return arg;

    case UNOP_IS_FINITE:
      if (!ShapeUtil::ElementIsFloating(arg)) {
        return InvalidArgument(
            "expected element type in shape to be floating point for IsFinite "
            "operation; got %s",
            PrimitiveType_Name(arg.element_type()).c_str());
      }
      return ShapeUtil::ChangeElementType(arg, PRED);

    default:
      return InvalidArgument(
          "Unknown operation for unary shape inference: \"%s\".",
          UnaryOperation_Name(operation).c_str());
  }
}

/* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
    tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
    const int64 dimension) {
  // Infers the shape of concatenating `arg_shapes` along `dimension`. All
  // operands must be arrays with the same rank and element type. They must
  // agree in every dimension except `dimension`, whose result extent is the
  // sum of the operands' extents.
  if (arg_shapes.empty()) {
    return InvalidArgument("Concatenate expects at least one argument");
  }
  // Reject tuple/opaque operands up front, so every operand is validated
  // before any rank is queried below. Previously the first operand's rank was
  // read before that operand had been checked.
  for (const Shape* shape : arg_shapes) {
    TF_RETURN_IF_ERROR(
        ExpectNotTupleOrOpaque(*shape, "operand of concatenation"));
  }

  const Shape* arg_shape = arg_shapes[0];
  const int64 rank = ShapeUtil::Rank(*arg_shape);
  if (dimension < 0 || dimension >= rank) {
    return InvalidArgument("dimension to concatenate along out of bounds: %lld",
                           dimension);
  }

  // Compare every remaining operand against the first one.
  for (size_t i = 1; i < arg_shapes.size(); ++i) {
    const Shape* shape = arg_shapes[i];
    if (ShapeUtil::Rank(*shape) != rank) {
      return InvalidArgument(
          "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld "
          "(%s)",
          rank, ShapeUtil::HumanString(*arg_shape).c_str(),
          ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape).c_str());
    }
    if (arg_shape->element_type() != shape->element_type()) {
      return InvalidArgument(
          "cannot concatenate arrays with different element types: %s vs %s",
          PrimitiveType_Name(arg_shape->element_type()).c_str(),
          PrimitiveType_Name(shape->element_type()).c_str());
    }
    for (int64 dimension_number = 0; dimension_number < rank;
         ++dimension_number) {
      // It's okay to differ in the dimension we're concatenating.
      if (dimension_number == dimension) {
        continue;
      }
      if (arg_shape->dimensions(dimension_number) !=
          shape->dimensions(dimension_number)) {
        // Report the dimension that actually mismatches, not the
        // concatenation dimension.
        return InvalidArgument(
            "cannot concatenate arrays that differ in dimensions other than "
            "the one being concatenated (the other array dimensions must be "
            "the same): %s vs %s in dimension %lld",
            ShapeUtil::HumanString(*arg_shape).c_str(),
            ShapeUtil::HumanString(*shape).c_str(), dimension_number);
      }
    }
  }

  std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
                                    arg_shape->dimensions().end());
  for (size_t i = 1; i < arg_shapes.size(); ++i) {
    new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
  }
  return ShapeUtil::MakeShape(arg_shape->element_type(), new_dimensions);
}

/* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
    const Shape& operand_shape, PrimitiveType new_element_type) {
  // Infers the shape of converting `operand_shape` to `new_element_type`:
  // same dimensions, new element type. Conversions from or to tuples are
  // rejected. A future extension could recurse into tuple elements and
  // validate each sub-conversion.
  const bool involves_tuple =
      ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE;
  if (!involves_tuple) {
    return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
  }
  return InvalidArgument(
      "cannot convert from or to tuple type; requested conversion: %s => %s",
      ShapeUtil::HumanString(operand_shape).c_str(),
      PrimitiveType_Name(new_element_type).c_str());
}

/* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
    const Shape& operand_shape, const int exponent_bits,
    const int mantissa_bits) {
  // ReducePrecision keeps the operand's shape; only its parameters and
  // element type are validated here.
  const bool is_floating = ShapeUtil::ElementIsFloating(operand_shape);
  if (!is_floating) {
    return InvalidArgument(
        "expected element type in shape to be floating point for "
        "ReducePrecision operation; got %s",
        PrimitiveType_Name(operand_shape.element_type()).c_str());
  }
  // At least one exponent bit is needed to tell 0 apart from infinity; with
  // zero exponent bits the result is not a sensible number.
  const bool exponent_ok = exponent_bits >= 1;
  if (!exponent_ok) {
    return InvalidArgument("expected exponent_bits >= 1; got %d",
                           exponent_bits);
  }
  // Zero mantissa bits, by contrast, still describes a meaningful format.
  const bool mantissa_ok = mantissa_bits >= 0;
  if (!mantissa_ok) {
    return InvalidArgument("expected non-negative mantissa_bits; got %d",
                           mantissa_bits);
  }
  return operand_shape;
}

/* static */ StatusOr<Shape> ShapeInference::InferPadShape(
    const Shape& operand_shape, const Shape& padding_value_shape,
    const PaddingConfig& padding_config) {
  // Infers the shape of padding `operand_shape` with a scalar padding value
  // according to `padding_config`. Each output dimension is
  //   operand + edge_low + edge_high + max(operand - 1, 0) * interior.
  // Edge padding may be negative, which trims elements. Interior padding must
  // be non-negative, and no resulting dimension may be negative.
  if (ShapeUtil::IsTuple(operand_shape)) {
    return InvalidArgument(
        "pad operation does not support tuple-shape operands");
  }
  if (!ShapeUtil::IsScalar(padding_value_shape)) {
    return InvalidArgument(
        "pad operation does not support non-scalar padding values");
  }
  if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) {
    return InvalidArgument(
        "The rank of the operand and the padding configuration do not match: "
        "%s vs %s",
        ShapeUtil::HumanString(operand_shape).c_str(),
        padding_config.ShortDebugString().c_str());
  }
  if (operand_shape.element_type() != padding_value_shape.element_type()) {
    return InvalidArgument(
        "the element types of the operands to pad do not match");
  }
  std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
  for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
    const auto& dim_padding = padding_config.dimensions(i);
    // A negative interior padding has no meaning (there is no way to remove
    // elements "between" elements). Reject it rather than silently computing
    // a bogus size.
    if (dim_padding.interior_padding() < 0) {
      return InvalidArgument("interior padding cannot be negative: %s",
                             padding_config.ShortDebugString().c_str());
    }
    const int64 operand_dim = operand_shape.dimensions(i);
    dimensions[i] = operand_dim + dim_padding.edge_padding_low() +
                    dim_padding.edge_padding_high() +
                    std::max<int64>(operand_dim - 1, 0LL) *
                        dim_padding.interior_padding();
    // Negative edge padding can trim away more elements than exist.
    if (dimensions[i] < 0) {
      return InvalidArgument(
          "padding results in negative size for dimension %lld", i);
    }
  }
  return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions);
}

/* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(const Shape& lhs,
                                                             const Shape& rhs) {
  // Infers the shape of `lhs <dot> rhs` for rank-1 or rank-2 arrays of the
  // same element type. The last dimension of lhs is contracted against
  // dimension 0 of rhs.
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot"));
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot"));

  // Builds an InvalidArgument status naming both operands plus an optional
  // reason. Captures by reference: the lambda never outlives this call, so
  // there is no need to copy the two Shape protos.
  auto fail = [&lhs, &rhs](const string& addendum) -> Status {
    string message = tensorflow::strings::Printf(
        "cannot infer shape for dot operation: %s <dot> %s",
        ShapeUtil::HumanString(lhs).c_str(),
        ShapeUtil::HumanString(rhs).c_str());
    if (!addendum.empty()) {
      message += ": " + addendum;
    }
    return InvalidArgument("%s", message.c_str());
  };

  // Check if both element types are the same.
  if (lhs.element_type() != rhs.element_type()) {
    return fail("element types do not match");
  }

  if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 ||
      ShapeUtil::Rank(rhs) < 1 || ShapeUtil::Rank(rhs) > 2) {
    return fail("dot only supports rank 1 or 2");
  }

  // Dimension -1 (the last) of lhs and dimension 0 of rhs are contracted.
  // Both ranks are at least 1 here, so both indices are in range.
  const int64 lhs_contracted_dimension = ShapeUtil::GetDimensionNumber(lhs, -1);
  const int64 rhs_contracted_dimension = 0;

  // Check if the contracted dimension sizes are the same.
  if (lhs.dimensions(lhs_contracted_dimension) !=
      rhs.dimensions(rhs_contracted_dimension)) {
    return fail("contracted dimensions mismatch");
  }

  // Each operand loses its contracted dimension. The result dimensions are
  // the remaining lhs dimensions, in order, followed by the remaining rhs
  // dimensions.
  std::vector<int64> dimensions;
  for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) {
    if (i != lhs_contracted_dimension) {
      dimensions.push_back(lhs.dimensions(i));
    }
  }
  for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) {
    if (i != rhs_contracted_dimension) {
      dimensions.push_back(rhs.dimensions(i));
    }
  }
  Shape result = ShapeUtil::MakeShape(lhs.element_type(), dimensions);

  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
  VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
  return result;
}

/* static */ StatusOr<Shape>
ShapeInference::InferDegenerateDimensionBroadcastShape(
    BinaryOperation operation, const Shape& lhs, const Shape& rhs) {
  // Broadcasts two equal-rank shapes against each other. In every dimension
  // the sizes must either match, or one of them must be 1 (a "degenerate"
  // dimension). The output takes the non-1 size. Any other combination is an
  // InvalidArgument naming `operation`.
  TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs));

  const int64 rank = ShapeUtil::Rank(lhs);
  std::vector<int64> result_dims;
  result_dims.reserve(rank);
  for (int64 d = 0; d < rank; ++d) {
    const int64 lhs_size = lhs.dimensions(d);
    const int64 rhs_size = rhs.dimensions(d);
    const bool compatible =
        lhs_size == rhs_size || lhs_size == 1 || rhs_size == 1;
    if (!compatible) {
      return InvalidArgument("binary op %s with incompatible shapes: %s and %s",
                             BinaryOperation_Name(operation).c_str(),
                             ShapeUtil::HumanString(lhs).c_str(),
                             ShapeUtil::HumanString(rhs).c_str());
    }
    // When lhs is degenerate, take rhs's size (this also covers equal sizes);
    // otherwise lhs is either equal to rhs or rhs is the degenerate side.
    result_dims.push_back(lhs_size == 1 ? rhs_size : lhs_size);
  }
  return ShapeUtil::MakeShape(lhs.element_type(), result_dims);
}

/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
    BinaryOperation operation, const Shape& smaller_shape,
    const Shape& larger_shape,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  // Only a scalar may be broadcast without explicit broadcast dimensions; for
  // any other lower-rank operand the caller must state where its dimensions
  // land in the higher-rank operand (see b/25177275).
  if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
    return InvalidArgument("automatic shape inference not supported: %s and %s",
                           ShapeUtil::HumanString(smaller_shape).c_str(),
                           ShapeUtil::HumanString(larger_shape).c_str());
  }
  // Exactly one broadcast dimension per dimension of the lower-rank operand.
  if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) {
    return InvalidArgument(
        "size of broadcast_dimensions has to match lower-rank operand's "
        "rank; "
        " lower-rank operand's rank is %lld, size of broadcast_dimensions is "
        "%zu",
        ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size());
  }

  // broadcast_dimensions[k] names the dimension of larger_shape that dimension
  // k of smaller_shape is aligned with. Aligned sizes must be equal, or one of
  // them must be 1. Entries must be strictly increasing. Examples:
  //
  // smaller_shape   larger_shape   broadcast_dimensions   output_shape
  //   []              [2, 3]          {}                    [2, 3]
  //   [3]             [4, 3]          {1}                   [4, 3]
  //   [2, 3]          [2, 3, 4]       {0, 1}                [2, 3, 4]
  //   [2, 1]          [2, 3, 4]       {0, 2}                [2, 3, 1]
  //   [2, 3]          [2, 1, 4]       {0, 1}                [2, 3, 4]
  //
  // output_shape is not necessarily the final result of the XLA operation.
  // The caller follows this "InDim" step with degenerate-dimension
  // broadcasting (InferDegenerateDimensionBroadcastShape), which stretches
  // size-1 dimensions. For the [2, 1] row above, this function returns
  // [2, 3, 1], and the operation's final shape is [2, 3, 4].
  //
  // Invalid broadcasts:
  //
  // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0}
  //   larger_shape dimension 0 (size 4) vs smaller_shape dimension 0 (size 3).
  //
  // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2}
  //   larger_shape dimension 1 (size 3) vs smaller_shape dimension 0 (size 2).

  // Start from larger_shape and overwrite each aligned dimension with the
  // corresponding size from smaller_shape.
  Shape output_shape(larger_shape);
  const int larger_rank = larger_shape.dimensions_size();

  for (int small_dim = 0; small_dim < smaller_shape.dimensions_size();
       ++small_dim) {
    const int64 large_dim = broadcast_dimensions.at(small_dim);
    if (large_dim < 0) {
      return InvalidArgument(
          "broadcast dimension number (%lld) cannot be negative", large_dim);
    }
    if (large_dim >= larger_rank) {
      return InvalidArgument(
          "broadcast dimension number (%lld) too large; higher-rank "
          "operand has rank %d",
          large_dim, larger_rank);
    }

    const int64 small_size = smaller_shape.dimensions(small_dim);
    const int64 large_size = larger_shape.dimensions(large_dim);
    // A size of 1 on either side is left for degenerate-dimension
    // broadcasting to resolve.
    const bool sizes_agree =
        small_size == large_size || small_size == 1 || large_size == 1;
    if (!sizes_agree) {
      return InvalidArgument(
          "broadcast dimension %d mismatch: %lld != %lld; %s and %s",
          small_dim, small_size, large_size,
          ShapeUtil::HumanString(smaller_shape).c_str(),
          ShapeUtil::HumanString(larger_shape).c_str());
    }

    if (small_dim > 0) {
      const int64 previous_large_dim = broadcast_dimensions.at(small_dim - 1);
      if (previous_large_dim >= large_dim) {
        return InvalidArgument(
            "broadcast dimensions order is wrong: %lld comes after %lld",
            large_dim, previous_large_dim);
      }
    }

    output_shape.set_dimensions(large_dim, small_size);
  }

  return output_shape;
}

// Infers the shape of an elementwise binary op. The operands must be
// non-tuple, non-opaque and share an element type. Handling depends on rank:
//   * Equal rank: broadcast_dimensions must be empty or the identity. Shapes
//     that are Compatible yield lhs; otherwise degenerate-dimension
//     broadcasting is applied.
//   * Different rank: InDim broadcasting of the lower-rank operand via
//     broadcast_dimensions, then degenerate-dimension broadcasting.
/* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
    BinaryOperation operation, const Shape& lhs, const Shape& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation"));
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation"));

  if (!ShapeUtil::SameElementType(lhs, rhs)) {
    return InvalidArgument(
        "binary op %s with different element types: %s and %s",
        BinaryOperation_Name(operation).c_str(),
        ShapeUtil::HumanString(lhs).c_str(),
        ShapeUtil::HumanString(rhs).c_str());
  }

  const int64 lhs_rank = ShapeUtil::Rank(lhs);
  const int64 rhs_rank = ShapeUtil::Rank(rhs);
  const bool same_rank = lhs_rank == rhs_rank;

  // With equal ranks, only the identity mapping (or nothing) is accepted.
  if (same_rank && !broadcast_dimensions.empty()) {
    std::vector<int64> identity_dims(lhs_rank);
    std::iota(identity_dims.begin(), identity_dims.end(), 0);
    if (broadcast_dimensions != identity_dims) {
      return InvalidArgument(
          "broadcast dimensions field must either be not set or be the "
          "identity on binary operations with operands of the same rank");
    }
  }

  // Identical apart from layout: the result is simply lhs.
  if (ShapeUtil::Compatible(lhs, rhs)) {
    return lhs;
  }

  if (same_rank) {
    return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
  }

  // Ranks differ (scalar broadcasting included): InDim broadcasting first,
  // then degenerate-dimension broadcasting against the higher-rank operand.
  const bool lhs_is_larger = lhs_rank > rhs_rank;
  const Shape& larger_shape = lhs_is_larger ? lhs : rhs;
  const Shape& smaller_shape = lhs_is_larger ? rhs : lhs;

  TF_ASSIGN_OR_RETURN(
      Shape indim_broadcast_shape,
      InferInDimBroadcastShape(operation, smaller_shape, larger_shape,
                               broadcast_dimensions));
  return InferDegenerateDimensionBroadcastShape(
      operation, indim_broadcast_shape, larger_shape);
}

// HLO-level entry point: converts the opcode to a BinaryOperation and forwards
// the two operands' shapes, with an empty broadcast_dimensions list, to the
// shape-based overload.
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
    HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
  const BinaryOperation operation = OpcodeToBinaryOperation(opcode);
  return InferBinaryOpShape(operation, lhs->shape(), rhs->shape(),
                            /*broadcast_dimensions=*/{});
}

// Infers the result shape of a binary operation on lhs and rhs. Dot is handled
// by InferDotOpShape. Every other supported op goes through elementwise
// broadcasting and may then have its element type adjusted:
//   * comparisons produce PRED;
//   * complex composition of F32 components produces C64.
// Unsupported operations return Unimplemented.
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
    BinaryOperation operation, const Shape& lhs, const Shape& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  const string op_name = BinaryOperation_Name(operation);
  VLOG(2) << tensorflow::strings::Printf(
      "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
      op_name.c_str(), ShapeUtil::HumanString(lhs).c_str(),
      ShapeUtil::HumanString(rhs).c_str(),
      tensorflow::str_util::Join(broadcast_dimensions, ", ").c_str());
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));

  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
      lhs, tensorflow::strings::StrCat("lhs of binary operation ", op_name)));
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
      rhs, tensorflow::strings::StrCat("rhs of binary operation ", op_name)));

  switch (operation) {
    case BINOP_DOT:
      return InferDotOpShape(lhs, rhs);

    // Comparisons broadcast like arithmetic but yield boolean elements.
    case BINOP_EQ:
    case BINOP_GE:
    case BINOP_GT:
    case BINOP_LE:
    case BINOP_LT:
    case BINOP_NE: {
      TF_ASSIGN_OR_RETURN(Shape compare_shape,
                          InferElementwiseBinaryOpShape(operation, lhs, rhs,
                                                        broadcast_dimensions));
      return ShapeUtil::ChangeElementType(compare_shape, PRED);
    }

    case BINOP_COMPLEX: {
      if (!ShapeUtil::ElementIsFloating(lhs)) {
        return InvalidArgument(
            "expected element type in shape to be floating for complex compose "
            "operation; got %s",
            PrimitiveType_Name(lhs.element_type()).c_str());
      }
      // Shape errors take priority over the component-type restriction.
      TF_ASSIGN_OR_RETURN(Shape complex_shape,
                          InferElementwiseBinaryOpShape(operation, lhs, rhs,
                                                        broadcast_dimensions));
      if (lhs.element_type() != F32) {
        return Unimplemented("complex component type not supported");
      }
      return ShapeUtil::ChangeElementType(complex_shape, C64);
    }

    // Logical/bitwise ops accept only PRED or integral element types.
    case BINOP_AND:
    case BINOP_OR: {
      const PrimitiveType element_type = lhs.element_type();
      if (element_type != PRED &&
          !primitive_util::IsIntegralType(element_type)) {
        return InvalidArgument(
            "expected pred or integral type in argument to and/or operation; "
            "got %s",
            PrimitiveType_Name(element_type).c_str());
      }
      return InferElementwiseBinaryOpShape(operation, lhs, rhs,
                                           broadcast_dimensions);
    }

    // Arithmetic and shift ops keep the operand element type.
    case BINOP_ADD:
    case BINOP_SUB:
    case BINOP_MUL:
    case BINOP_DIV:
    case BINOP_REM:
    case BINOP_MAX:
    case BINOP_MIN:
    case BINOP_POW:
    case BINOP_ATAN2:
    case BINOP_SHIFT_LEFT:
    case BINOP_SHIFT_RIGHT_ARITHMETIC:
    case BINOP_SHIFT_RIGHT_LOGICAL:
      return InferElementwiseBinaryOpShape(operation, lhs, rhs,
                                           broadcast_dimensions);

    default:
      return Unimplemented(
          "not yet implemented; infer binary op shape: %s; lhs: %s; rhs: %s",
          op_name.c_str(), lhs.ShortDebugString().c_str(),
          rhs.ShortDebugString().c_str());
  }
}

// HLO-level entry point: maps the opcode to a TernaryOperation and delegates
// to the shape-based overload with the three operands' shapes.
/* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
    HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs,
    const HloInstruction* ehs) {
  const TernaryOperation operation = OpcodeToTernaryOperation(opcode);
  return InferTernaryOpShape(operation, lhs->shape(), rhs->shape(),
                             ehs->shape());
}

// Dispatches a ternary operation to its shape-inference routine. Clamp and
// select are supported; any other operation yields InvalidArgument.
/* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
    TernaryOperation operation, const Shape& lhs, const Shape& rhs,
    const Shape& ehs) {
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));

  if (operation == TRIOP_CLAMP) {
    return InferClampShape(lhs, rhs, ehs);
  }
  if (operation == TRIOP_SELECT) {
    return InferSelectShape(lhs, rhs, ehs);
  }
  return InvalidArgument("unknown operation %s",
                         TernaryOperation_Name(operation).c_str());
}

// HLO-level entry point: gathers pointers to each operand's shape and
// delegates to the shape-based overload.
/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
    HloOpcode opcode,
    tensorflow::gtl::ArraySlice<const HloInstruction*> operands) {
  std::vector<const Shape*> operand_shapes;
  // The final size is known up front; avoid repeated reallocation.
  operand_shapes.reserve(operands.size());
  for (const HloInstruction* operand : operands) {
    operand_shapes.push_back(&operand->shape());
  }
  return InferVariadicOpShape(OpcodeToVariadicOperation(opcode),
                              operand_shapes);
}

// Infers the shape of a variadic operation. Only VAROP_TUPLE is supported: the
// result is a tuple whose element shapes are the operand shapes, in order.
// Any other operation yields InvalidArgument.
/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
    VariadicOperation operation,
    tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
  for (const Shape* operand_shape : operand_shapes) {
    TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*operand_shape));
  }

  if (operation != VAROP_TUPLE) {
    return InvalidArgument("unknown operation %s",
                           VariadicOperation_Name(operation).c_str());
  }

  Shape tuple_shape = ShapeUtil::MakeTupleShape({});
  for (const Shape* element_shape : operand_shapes) {
    ShapeUtil::AppendShapeToTuple(*element_shape, &tuple_shape);
  }
  return tuple_shape;
}

// Infers the result shape of Map(arg_shapes, to_apply, dimensions).
//
// Operands must be non-tuple, non-opaque, and share one shape. Scalar operands
// are also accepted next to a non-scalar operand of the same element type.
// Only mapping over all dimensions in order (dimensions == {0, 1, ..., r-1}) is
// supported. to_apply must take one scalar parameter per operand, each with
// the operands' element type, and return a scalar. The result has the
// operands' dimensions and to_apply's result element type.
/* static */ StatusOr<Shape> ShapeInference::InferMapShape(
    tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
    const ProgramShape& to_apply,
    tensorflow::gtl::ArraySlice<int64> dimensions) {
  if (arg_shapes.empty()) {
    return InvalidArgument("Map expects at least one argument");
  }

  // arg_shape ends up as the operands' common non-scalar shape, or
  // arg_shapes[0] when every operand is scalar.
  const Shape* arg_shape = arg_shapes[0];
  // The loop below validates operands 1..N-1 only. Check operand 0 here so a
  // tuple/opaque first operand (or a sole operand) is reported as such.
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(*arg_shape, "operand of map"));
  for (size_t i = 1; i < arg_shapes.size(); ++i) {
    TF_RETURN_IF_ERROR(
        ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map"));

    if (ShapeUtil::Compatible(*arg_shapes[i], *arg_shape)) {
      continue;
    }
    // A scalar may stand in for any shape with the same element type.
    if (!ShapeUtil::IsTuple(*arg_shapes[i]) &&
        !ShapeUtil::IsTuple(*arg_shape) &&
        ShapeUtil::SameElementType(*arg_shapes[i], *arg_shape)) {
      if (ShapeUtil::IsScalar(*arg_shapes[i])) {
        continue;
      }
      if (ShapeUtil::IsScalar(*arg_shape)) {
        arg_shape = arg_shapes[i];
        continue;
      }
    }

    std::vector<string> pieces;
    for (const Shape* shape : arg_shapes) {
      pieces.push_back(ShapeUtil::HumanString(*shape));
    }
    return InvalidArgument(
        "Map operation requires all operands to have the same shape; got: "
        "%s",
        tensorflow::str_util::Join(pieces, ", ").c_str());
  }

  // Check that dimensions.size == arg_shape.dimensions_size() (we currently
  // only support mapping across all dimensions: i.e. scalar map functions).
  if (dimensions.size() != arg_shape->dimensions_size()) {
    return InvalidArgument(
        "Map applied to a subset of dimensions currently not supported: "
        "arg_dimension_size: %d, requested_map_dimensions_size: %zu",
        arg_shape->dimensions_size(), dimensions.size());
  }

  // Check that requested map dimensions numbers are monotonically increasing.
  for (int i = 0; i < dimensions.size(); ++i) {
    if (dimensions[i] != i) {
      return InvalidArgument(
          "Map requires monotonically increasing dimension numbers, found: %s ",
          tensorflow::str_util::Join(dimensions, ", ").c_str());
    }
  }

  // The applied function's arity equals the number of arguments.
  if (arg_shapes.size() != to_apply.parameters_size()) {
    return InvalidArgument(
        "Map applied function arity must match number of arguments; got: "
        "arity: %d, arguments: %zu",
        to_apply.parameters_size(), arg_shapes.size());
  }

  // The parameters should all be scalars, and the output too.
  const Shape& output_shape = to_apply.result();
  if (!ShapeUtil::IsScalar(output_shape)) {
    return InvalidArgument(
        "mapped computation's result has to be a scalar; "
        "got: %s",
        ShapeUtil::HumanString(output_shape).c_str());
  }

  for (int i = 0; i < to_apply.parameters_size(); ++i) {
    const Shape& parameter_shape = to_apply.parameters(i);

    if (!ShapeUtil::IsScalar(parameter_shape)) {
      return InvalidArgument(
          "mapped computation's parameter has to be a scalar; "
          "got parameter %d shape: %s",
          i, ShapeUtil::HumanString(parameter_shape).c_str());
    }

    if (parameter_shape.element_type() != arg_shape->element_type()) {
      return InvalidArgument(
          "mapped computation's parameter type has to match argument element "
          "type; got parameter %d shape: %s, argument shape: %s",
          i, ShapeUtil::HumanString(parameter_shape).c_str(),
          ShapeUtil::HumanString(*arg_shape).c_str());
    }
  }

  return ShapeUtil::MakeShape(output_shape.element_type(),
                              AsInt64Slice(arg_shape->dimensions()));
}

// Infers the result shape of batch-norm-training.
//
// scale and offset must be rank-1, have the operand's element type, and have
// length operand_shape.dimensions(feature_index). The operand must be a
// floating-point array of rank >= 1. The result is a tuple of operand_shape
// followed twice by the rank-1 [feature_count] shape used for the mean and
// variance outputs.
/* static */ StatusOr<Shape> ShapeInference::InferBatchNormTrainingShape(
    const Shape& operand_shape, const Shape& scale_shape,
    const Shape& offset_shape, int64 feature_index) {
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm training"));
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
      offset_shape, "offset input of batch norm training"));
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
      scale_shape, "scale input of batch norm training"));

  // Propagate the validation error itself rather than collapsing it into a
  // generic RET_CHECK failure (matches InferBatchNormGradShape).
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape));
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));

  if (feature_index >= ShapeUtil::Rank(operand_shape)) {
    return InvalidArgument(
        "Expected feature_index of batch-norm-training to be "
        "smaller than the rank of operand_shape; "
        "got feature_index %lld, and rank %lld",
        feature_index, ShapeUtil::Rank(operand_shape));
  }

  if (feature_index < 0) {
    return InvalidArgument(
        "Expected feature_index of batch-norm-training to "
        "be a non-negative number, got %lld",
        feature_index);
  }

  if (ShapeUtil::Rank(operand_shape) < 1) {
    return InvalidArgument(
        "Expected the rank of operand to "
        "batch-norm-training to be at least 1; got %lld",
        ShapeUtil::Rank(operand_shape));
  }

  if (ShapeUtil::Rank(offset_shape) != 1) {
    return InvalidArgument(
        "Offset input of batch-norm-training must have"
        " rank 1, but has rank %lld.",
        ShapeUtil::Rank(offset_shape));
  }

  if (ShapeUtil::Rank(scale_shape) != 1) {
    return InvalidArgument(
        "Scale input of batch-norm-training must have"
        " rank 1, but has rank %lld.",
        ShapeUtil::Rank(scale_shape));
  }

  if (!ShapeUtil::ElementIsFloating(operand_shape)) {
    return InvalidArgument(
        "The operand to batch-norm-training must have a floating point "
        "element type, but the shape is %s",
        PrimitiveType_Name(operand_shape.element_type()).c_str());
  }

  if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
    return InvalidArgument(
        "The inputs should have the same element type for batch-norm-training, "
        "but the shape of offset factor is %s "
        "and the shape of operand is %s",
        PrimitiveType_Name(offset_shape.element_type()).c_str(),
        PrimitiveType_Name(operand_shape.element_type()).c_str());
  }

  if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
    return InvalidArgument(
        "The inputs should have the same element type for batch-norm-training, "
        "but the shape of scale factor is %s "
        "and the shape of operand is %s",
        PrimitiveType_Name(scale_shape.element_type()).c_str(),
        PrimitiveType_Name(operand_shape.element_type()).c_str());
  }

  const int64 feature_count = operand_shape.dimensions(feature_index);
  Shape output_shape_for_mean_and_var =
      ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});

  if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
    return InvalidArgument(
        "The size of offset factor should be the same as feature count,"
        "but the size of offset factor is %lld "
        "and the feature count is %lld",
        ShapeUtil::GetDimension(offset_shape, 0), feature_count);
  }

  if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
    return InvalidArgument(
        "The size of scale factor should be the same as feature count,"
        "but the size of scale factor is %lld "
        "and the feature count is %lld",
        ShapeUtil::GetDimension(scale_shape, 0), feature_count);
  }

  return ShapeUtil::MakeTupleShape({operand_shape,
                                    output_shape_for_mean_and_var,
                                    output_shape_for_mean_and_var});
}

// Infers the result shape of batch-norm-inference, which is operand_shape.
//
// scale, offset, mean and variance must each be non-tuple, rank-1 arrays with
// the operand's element type and length operand_shape.dimensions(feature_index).
// The operand must be a floating-point array of rank >= 1.
/* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
    const Shape& operand_shape, const Shape& scale_shape,
    const Shape& offset_shape, const Shape& mean_shape,
    const Shape& variance_shape, int64 feature_index) {
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference"));
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
      offset_shape, "offset input of batch norm inference"));
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
      scale_shape, "scale input of batch norm inference"));
  // mean and variance feed ShapeUtil::Rank / GetDimension below, so they need
  // the same tuple/opaque guard as the other inputs.
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
      mean_shape, "mean input of batch norm inference"));
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
      variance_shape, "variance input of batch norm inference"));

  // Propagate the validation error itself (matches InferBatchNormGradShape).
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape));
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
  TF_RETURN_IF_ERROR(
      ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape));

  if (feature_index >= ShapeUtil::Rank(operand_shape)) {
    return InvalidArgument(
        "Expected feature_index of batch-norm-inference to be "
        "smaller than the rank of operand_shape; "
        "got feature_index %lld, and rank %lld",
        feature_index, ShapeUtil::Rank(operand_shape));
  }

  if (feature_index < 0) {
    return InvalidArgument(
        "Expected feature_index of batch-norm-inference to "
        "be a non-negative number, got %lld",
        feature_index);
  }

  if (ShapeUtil::Rank(operand_shape) < 1) {
    return InvalidArgument(
        "Expected the rank of operand to "
        "batch-norm-inference to be at least 1; got %lld",
        ShapeUtil::Rank(operand_shape));
  }

  if (ShapeUtil::Rank(offset_shape) != 1) {
    return InvalidArgument(
        "Offset input of batch-norm-inference must have"
        " rank 1, but has rank %lld.",
        ShapeUtil::Rank(offset_shape));
  }

  if (ShapeUtil::Rank(scale_shape) != 1) {
    return InvalidArgument(
        "Scale input of batch-norm-inference must have"
        " rank 1, but has rank %lld.",
        ShapeUtil::Rank(scale_shape));
  }

  // GetDimension(mean_shape, 0) / GetDimension(variance_shape, 0) below
  // require at least one dimension; enforce rank 1 as for scale and offset.
  if (ShapeUtil::Rank(mean_shape) != 1) {
    return InvalidArgument(
        "Mean input of batch-norm-inference must have"
        " rank 1, but has rank %lld.",
        ShapeUtil::Rank(mean_shape));
  }

  if (ShapeUtil::Rank(variance_shape) != 1) {
    return InvalidArgument(
        "Variance input of batch-norm-inference must have"
        " rank 1, but has rank %lld.",
        ShapeUtil::Rank(variance_shape));
  }

  if (!ShapeUtil::ElementIsFloating(operand_shape)) {
    return InvalidArgument(
        "The operand to batch-norm-inference must have a floating point "
        "element type, but the shape is %s",
        PrimitiveType_Name(operand_shape.element_type()).c_str());
  }

  if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
    return InvalidArgument(
        "The inputs should have the same element type for "
        "batch-norm-inference, "
        "but the shape of offset factor is %s "
        "and the shape of operand is %s",
        PrimitiveType_Name(offset_shape.element_type()).c_str(),
        PrimitiveType_Name(operand_shape.element_type()).c_str());
  }

  if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
    return InvalidArgument(
        "The inputs should have the same element type for "
        "batch-norm-inference, "
        "but the shape of scale factor is %s "
        "and the shape of operand is %s",
        PrimitiveType_Name(scale_shape.element_type()).c_str(),
        PrimitiveType_Name(operand_shape.element_type()).c_str());
  }

  if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
    return InvalidArgument(
        "The inputs should have the same element type for "
        "batch-norm-inference, "
        "but the shape of mean is %s "
        "and the shape of operand is %s",
        PrimitiveType_Name(mean_shape.element_type()).c_str(),
        PrimitiveType_Name(operand_shape.element_type()).c_str());
  }

  if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) {
    // Report variance's and operand's element types, as the message states.
    return InvalidArgument(
        "The inputs should have the same element type for "
        "batch-norm-inference, "
        "but the shape of variance is %s "
        "and the shape of operand is %s",
        PrimitiveType_Name(variance_shape.element_type()).c_str(),
        PrimitiveType_Name(operand_shape.element_type()).c_str());
  }

  const int64 feature_count = operand_shape.dimensions(feature_index);

  if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
    return InvalidArgument(
        "The size of offset factor should be the same as feature count,"
        "but the size of offset factor is %lld "
        "and the feature count is %lld",
        ShapeUtil::GetDimension(offset_shape, 0), feature_count);
  }

  if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
    return InvalidArgument(
        "The size of scale factor should be the same as feature count,"
        "but the size of scale factor is %lld "
        "and the feature count is %lld",
        ShapeUtil::GetDimension(scale_shape, 0), feature_count);
  }

  if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
    return InvalidArgument(
        "The size of mean should be the same as feature count,"
        "but the size of mean is %lld "
        "and the feature count is %lld",
        ShapeUtil::GetDimension(mean_shape, 0), feature_count);
  }

  if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
    return InvalidArgument(
        "The size of variance should be the same as feature count,"
        "but the size of variance is %lld "
        "and the feature count is %lld",
        ShapeUtil::GetDimension(variance_shape, 0), feature_count);
  }

  return operand_shape;
}

// Infers the result shape of batch-norm-grad.
//
// operand and output_grad must be floating-point arrays with identical
// element type, rank and dimension sizes. scale, mean and var must be rank-1
// arrays with the operand's element type and length
// operand_shape.dimensions(feature_index). feature_index must lie in
// [0, rank). The result is a tuple of operand_shape followed by two rank-1
// [feature_count] shapes.
/* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
    const Shape& operand_shape, const Shape& scale_shape,
    const Shape& mean_shape, const Shape& var_shape,
    const Shape& output_grad_shape, int64 feature_index) {
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm grad"));
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(scale_shape, "scale input of batch norm grad"));
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(mean_shape, "mean input of batch norm grad"));
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(var_shape, "var input of batch norm grad"));
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
      output_grad_shape, "output_grad input of batch norm grad"));

  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape));
  TF_RETURN_IF_ERROR(
      ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape));

  if (feature_index >= ShapeUtil::Rank(operand_shape)) {
    return InvalidArgument(
        "Expected feature_index of batch-norm-grad to be "
        "smaller than the rank of operand_shape; "
        "got feature_index %lld, and rank %lld",
        feature_index, ShapeUtil::Rank(operand_shape));
  }

  // Without this, operand_shape.dimensions(feature_index) below would index
  // with a negative value (training and inference already reject this).
  if (feature_index < 0) {
    return InvalidArgument(
        "Expected feature_index of batch-norm-grad to "
        "be a non-negative number, got %lld",
        feature_index);
  }

  if (ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) {
    return InvalidArgument(
        "Expected operand_shape of batch-norm-grad to have the same rank as"
        " output_grad_shape; got rank(operand_shape) %lld, and"
        " rank(output_grad_shape) %lld",
        ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape));
  }

  if (ShapeUtil::Rank(mean_shape) != 1) {
    return InvalidArgument(
        "Mean input of batch-norm-grad must have"
        " rank 1, but has rank %lld.",
        ShapeUtil::Rank(mean_shape));
  }

  if (ShapeUtil::Rank(scale_shape) != 1) {
    return InvalidArgument(
        "Scale input of batch-norm-grad must have"
        " rank 1, but has rank %lld.",
        ShapeUtil::Rank(scale_shape));
  }

  if (ShapeUtil::Rank(var_shape) != 1) {
    return InvalidArgument(
        "Var input of batch-norm-grad must have"
        " rank 1, but has rank %lld.",
        ShapeUtil::Rank(var_shape));
  }

  if (!ShapeUtil::ElementIsFloating(operand_shape)) {
    return InvalidArgument(
        "The operand to batch-norm-grad must have a floating point "
        "element type, but the shape is %s",
        PrimitiveType_Name(operand_shape.element_type()).c_str());
  }

  if (!ShapeUtil::ElementIsFloating(output_grad_shape)) {
    return InvalidArgument(
        "The output_grad to batch-norm-grad must have a floating point "
        "element type, but the shape is %s",
        PrimitiveType_Name(output_grad_shape.element_type()).c_str());
  }

  if (!ShapeUtil::SameElementType(output_grad_shape, operand_shape)) {
    return InvalidArgument(
        "The inputs should have the same element type for batch-norm-grad, "
        "but the element type of output_grad is %s "
        "and the element type of operand is %s",
        PrimitiveType_Name(output_grad_shape.element_type()).c_str(),
        PrimitiveType_Name(operand_shape.element_type()).c_str());
  }

  if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
    return InvalidArgument(
        "The inputs should have the same element type for batch-norm-grad, "
        "but the element type of scale factor is %s "
        "and the element type of operand is %s",
        PrimitiveType_Name(scale_shape.element_type()).c_str(),
        PrimitiveType_Name(operand_shape.element_type()).c_str());
  }

  if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
    return InvalidArgument(
        "The inputs should have the same element type for batch-norm-grad, "
        "but the element type of mean is %s "
        "and the element type of operand is %s",
        PrimitiveType_Name(mean_shape.element_type()).c_str(),
        PrimitiveType_Name(operand_shape.element_type()).c_str());
  }

  if (!ShapeUtil::SameElementType(var_shape, operand_shape)) {
    // Name and print the variance input, not mean.
    return InvalidArgument(
        "The inputs should have the same element type for batch-norm-grad, "
        "but the element type of variance is %s "
        "and the element type of operand is %s",
        PrimitiveType_Name(var_shape.element_type()).c_str(),
        PrimitiveType_Name(operand_shape.element_type()).c_str());
  }

  const int64 feature_count = operand_shape.dimensions(feature_index);

  Shape feature_shape =
      ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});

  if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
    return InvalidArgument(
        "The size of mean should be the same as feature count,"
        "but the size of mean is %lld "
        "and the feature count is %lld",
        ShapeUtil::GetDimension(mean_shape, 0), feature_count);
  }

  if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
    return InvalidArgument(
        "The size of scale factor should be the same as feature count,"
        "but the size of scale factor is %lld "
        "and the feature count is %lld",
        ShapeUtil::GetDimension(scale_shape, 0), feature_count);
  }

  if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) {
    return InvalidArgument(
        "The size of variance should be the same as feature count,"
        "but the size of variance is %lld "
        "and the feature count is %lld",
        ShapeUtil::GetDimension(var_shape, 0), feature_count);
  }

  // Verify operand_shape and output_grad_shape have same bounds.
  for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
    if (ShapeUtil::GetDimension(operand_shape, i) !=
        ShapeUtil::GetDimension(output_grad_shape, i)) {
      return InvalidArgument(
          "The bounds of operand shape should be the same as output_grad's,"
          "but the bound of operand_shape at dimension %lld is %lld "
          "and the bound of output_grad_shape is %lld",
          i, ShapeUtil::GetDimension(operand_shape, i),
          ShapeUtil::GetDimension(output_grad_shape, i));
    }
  }

  return ShapeUtil::MakeTupleShape(
      {operand_shape, feature_shape, feature_shape});
}

// Infers the result shape of convolving `lhs` (the input) with `rhs` (the
// kernel) under `window`, using `dnums` to locate batch/feature/spatial
// dimensions in each operand and in the output.
//
// Checks performed, each reported as InvalidArgument:
//  * both operands are arrays (not tuple/opaque) with the same element type;
//  * `dnums` and `window` agree on the number of spatial dimensions, and both
//    operands have rank num_spatial_dims + 2;
//  * the input, kernel and output dimension numbers are each a permutation of
//    [0, rank);
//  * the input feature size equals the kernel input feature size;
//  * the window sizes equal the kernel spatial sizes.
// The output spatial sizes come from InferWindowOutputShape with negative
// padding allowed. They are stored at the indices given by
// dnums.spatial_dimensions(). Batch and feature are stored at
// dnums.output_batch_dimension() / dnums.output_feature_dimension().
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
    const Shape& lhs, const Shape& rhs, const Window& window,
    const ConvolutionDimensionNumbers& dnums) {
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution"));
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution"));

  if (!ShapeUtil::SameElementType(lhs, rhs)) {
    return InvalidArgument(
        "Convolution with different element types: %s and %s",
        ShapeUtil::HumanString(lhs).c_str(),
        ShapeUtil::HumanString(rhs).c_str());
  }
  if (dnums.spatial_dimensions_size() !=
      dnums.kernel_spatial_dimensions_size()) {
    return InvalidArgument(
        "Both arguments to convolution must have same number of dimensions.\n"
        "Window: %s",
        window.DebugString().c_str());
  }

  const int num_spatial_dims = dnums.spatial_dimensions_size();
  if (window.dimensions_size() != num_spatial_dims) {
    return InvalidArgument(
        "Window must have same number of dimensions as dimension numbers.\n"
        "Window: %s\nDimension numbers: %s",
        window.DebugString().c_str(), dnums.DebugString().c_str());
  }

  const int num_dims = num_spatial_dims + 2;
  if (ShapeUtil::Rank(lhs) != num_dims) {
    return InvalidArgument(
        "The LHS argument to a convolution should have rank %d.\n"
        "lhs: %s",
        num_dims, ShapeUtil::HumanString(lhs).c_str());
  }
  if (ShapeUtil::Rank(rhs) != num_dims) {
    // Report the RHS shape here; this is the operand whose rank is wrong.
    return InvalidArgument(
        "The RHS argument to a convolution should have rank %d.\n"
        "rhs: %s",
        num_dims, ShapeUtil::HumanString(rhs).c_str());
  }
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));

  // Verifies that the input and window dimensions are a permutation of
  // the dimension numbers.
  std::vector<int64> input_dnums(num_dims);
  input_dnums[0] = dnums.input_batch_dimension();
  input_dnums[1] = dnums.input_feature_dimension();
  std::copy(dnums.spatial_dimensions().begin(),
            dnums.spatial_dimensions().end(), input_dnums.begin() + 2);
  std::sort(input_dnums.begin(), input_dnums.end());

  std::vector<int64> window_dnums(num_dims);
  window_dnums[0] = dnums.kernel_input_feature_dimension();
  window_dnums[1] = dnums.kernel_output_feature_dimension();
  std::copy(dnums.kernel_spatial_dimensions().begin(),
            dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
  std::sort(window_dnums.begin(), window_dnums.end());

  // The output shape built at the end of this function has a layout fixed by
  // dnums. Batch and feature go at the output dimension numbers. The spatial
  // sizes go at the same indices as the input spatial dimensions. These
  // indices must therefore also be a permutation of [0, num_dims). Otherwise
  // the writes into `dimensions` below could go out of bounds, or overwrite
  // one another and leave an entry unset.
  std::vector<int64> output_dnums(num_dims);
  output_dnums[0] = dnums.output_batch_dimension();
  output_dnums[1] = dnums.output_feature_dimension();
  std::copy(dnums.spatial_dimensions().begin(),
            dnums.spatial_dimensions().end(), output_dnums.begin() + 2);
  std::sort(output_dnums.begin(), output_dnums.end());

  std::vector<int64> expected_dnums(num_dims);
  std::iota(expected_dnums.begin(), expected_dnums.end(), 0);

  const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
  if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) ||
      !std::all_of(window_dnums.begin(), window_dnums.end(), in_range) ||
      !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) {
    return InvalidArgument(
        "A dimension number is out of range in convolution: %s",
        dnums.DebugString().c_str());
  }

  if (input_dnums != expected_dnums) {
    return InvalidArgument(
        "Input dimensions of convolution must contain each dimension exactly "
        "once: %s",
        dnums.DebugString().c_str());
  }
  if (window_dnums != expected_dnums) {
    return InvalidArgument(
        "Window dimensions of convolution must contain each dimension exactly "
        "once: %s",
        dnums.DebugString().c_str());
  }
  if (output_dnums != expected_dnums) {
    return InvalidArgument(
        "Output dimensions of convolution must contain each dimension exactly "
        "once: %s",
        dnums.DebugString().c_str());
  }

  std::vector<int64> input_spatial_dims(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_spatial_dims[i] = lhs.dimensions(dnums.spatial_dimensions(i));
  }
  const int64 input_features = lhs.dimensions(dnums.input_feature_dimension());
  const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension());

  std::vector<int64> kernel_spatial_dims(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i));
  }
  const int64 kernel_input_features =
      rhs.dimensions(dnums.kernel_input_feature_dimension());
  const int64 kernel_output_features =
      rhs.dimensions(dnums.kernel_output_feature_dimension());

  if (input_features != kernel_input_features) {
    return InvalidArgument(
        "Expected LHS feature dimension (value %lld) to match RHS "
        "input feature dimension (value %lld); got <conv>(%s, %s)\n"
        "Dimension numbers: {%s}",
        input_features, kernel_input_features,
        ShapeUtil::HumanString(lhs).c_str(),
        ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str());
  }
  std::vector<int64> window_dims(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    window_dims[i] = window.dimensions(i).size();
  }
  if (kernel_spatial_dims != window_dims) {
    return InvalidArgument(
        "Window dimensions do not match RHS shape:\n\t"
        "RHS shape: %s\n\t"
        "Window: {%s}\n\t"
        "Dimension numbers: {%s}",
        ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(),
        dnums.ShortDebugString().c_str());
  }

  // Slide the window over the input's spatial extent only; batch and feature
  // are handled separately below.
  Shape base_shape =
      ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims);
  TF_ASSIGN_OR_RETURN(
      Shape window_output_shape,
      InferWindowOutputShape(base_shape, window, lhs.element_type(),
                             /*allow_negative_padding=*/true));

  // All indices used here were validated as a permutation of [0, num_dims)
  // above, so every entry is written exactly once.
  std::vector<int64> dimensions(num_dims);
  dimensions[dnums.output_batch_dimension()] = input_batch;
  dimensions[dnums.output_feature_dimension()] = kernel_output_features;
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimensions[dnums.spatial_dimensions(i)] = window_output_shape.dimensions(i);
  }

  return ShapeUtil::MakeShape(lhs.element_type(), dimensions);
}

// The result of a cross-replica sum has exactly the operand's shape. Only
// array operands are accepted; tuple and opaque shapes are rejected.
/* static */ StatusOr<Shape> ShapeInference::InferCrossReplicaSumShape(
    const Shape& operand) {
  const char* const kWhat = "operand of cross replica sum";
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, kWhat));
  return operand;
}

// Infers the shape produced by reducing `arg` over `dimensions_to_reduce`
// with the reducer computation `to_apply`, seeded with `init_value`.
//
// The result keeps the non-reduced dimensions of `arg` in their original
// order. Its element type is taken from the reducer's result type.
// Returns InvalidArgument in these cases:
//  * `arg` is a tuple or opaque shape;
//  * a reduce dimension lies outside [0, rank(arg)).
// Reducer-signature problems are reported by VerifyReducerShape.
/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
    const Shape& arg, const Shape& init_value,
    tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
    const ProgramShape& to_apply) {
  // Reject non-array operands up front, as the other inference routines in
  // this file do, before the rank of `arg` is queried below.
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of reduce"));

  // Check that the dimensions to reduce are in-bounds for the given shape.
  for (int64 dimension : dimensions_to_reduce) {
    if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) {
      return InvalidArgument(
          "attempting to reduce out-of-bounds dimension %lld in shape %s",
          dimension, ShapeUtil::HumanString(arg).c_str());
    }
  }
  TF_RETURN_IF_ERROR(
      VerifyReducerShape(to_apply, init_value, arg.element_type()));

  // Duplicated entries in `dimensions_to_reduce` collapse in the set, so a
  // dimension listed twice is still removed only once.
  std::set<int64> dimensions_to_reduce_set(dimensions_to_reduce.begin(),
                                           dimensions_to_reduce.end());
  std::vector<int64> new_dimensions;
  for (int64 i = 0; i < ShapeUtil::Rank(arg); ++i) {
    if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
      new_dimensions.push_back(arg.dimensions(i));
    }
  }

  return ShapeUtil::MakeShape(to_apply.result().element_type(), new_dimensions);
}

// Infers the shape of a reduce-window over `operand_shape`.
// `to_apply_shape` must be a valid reducer for `init_value_shape` over the
// operand's element type. The output is the windowed shape of the operand,
// typed with the init value's element type. Negative padding is rejected.
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
    const Shape& operand_shape, const Shape& init_value_shape,
    const Window& window, const ProgramShape& to_apply_shape) {
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(operand_shape, "operand of reduce-window"));

  const PrimitiveType operand_type = operand_shape.element_type();
  TF_RETURN_IF_ERROR(
      VerifyReducerShape(to_apply_shape, init_value_shape, operand_type));

  const PrimitiveType result_type = init_value_shape.element_type();
  return InferWindowOutputShape(operand_shape, window, result_type,
                                /*allow_negative_padding=*/false);
}

// Infers the shape of select-and-scatter, whose result has the operand's
// shape.
//
// Validates that:
//  * `select_shape` has signature (T, T) -> PRED, where T is the operand's
//    element type as a scalar;
//  * `scatter_shape` is a valid reducer for `init_value_shape` over the
//    source's element type;
//  * `source_shape` matches the shape of the operand after windowing by
//    `window` (no negative padding).
/* static */ StatusOr<Shape> ShapeInference::InferSelectAndScatterShape(
    const Shape& operand_shape, const ProgramShape& select_shape,
    const Window& window, const Shape& source_shape,
    const Shape& init_value_shape, const ProgramShape& scatter_shape) {
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(operand_shape, "operand of select-and-scatter"));

  // --- select: (T, T) -> PRED ---
  if (select_shape.parameters_size() != 2) {
    return InvalidArgument(
        "select function must take 2 parameters, but "
        "takes %d parameter(s).",
        select_shape.parameters_size());
  }
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  if (!ShapeUtil::Compatible(select_shape.result(), pred_scalar)) {
    return Unimplemented("select function must have rank-0 PRED result.");
  }

  const Shape element_shape =
      ShapeUtil::MakeShape(operand_shape.element_type(), {});
  const char* const kOrdinals[] = {"first", "second"};
  for (int param = 0; param < 2; ++param) {
    const Shape& param_shape = select_shape.parameters(param);
    if (!ShapeUtil::Compatible(element_shape, param_shape)) {
      return InvalidArgument(
          "select function's %s parameter shape currently must "
          "match the operand element shape. Got %s vs %s",
          kOrdinals[param], ShapeUtil::HumanString(param_shape).c_str(),
          ShapeUtil::HumanString(element_shape).c_str());
    }
  }

  // --- scatter: must be a reducer over the source element type ---
  TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, init_value_shape,
                                        source_shape.element_type()));

  // --- source must match the window-reduced operand ---
  TF_ASSIGN_OR_RETURN(Shape windowed_shape,
                      InferWindowOutputShape(operand_shape, window,
                                             operand_shape.element_type(),
                                             /*allow_negative_padding=*/false));
  if (!ShapeUtil::Compatible(source_shape, windowed_shape)) {
    return InvalidArgument(
        "source shape does not match the shape of window-reduced operand: "
        "source(%s), window-reduced operand(%s)",
        ShapeUtil::HumanString(source_shape).c_str(),
        ShapeUtil::HumanString(windowed_shape).c_str());
  }
  return operand_shape;
}

// Infers the shape of a static strided slice of `arg`. Along each dimension d
// the result has ceil((limits[d] - starts[d]) / strides[d]) elements.
//
// `starts`, `limits` and `strides` must each provide one entry per dimension
// of `arg`, with 0 <= starts[d] <= limits[d] <= arg.dimensions(d) and
// strides[d] > 0. Any violation yields InvalidArgument.
/* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
    const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
    tensorflow::gtl::ArraySlice<int64> limits,
    tensorflow::gtl::ArraySlice<int64> strides) {
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice"));
  VLOG(2) << tensorflow::strings::Printf(
      "slicing shape %s starts={%s} limits={%s}",
      ShapeUtil::HumanString(arg).c_str(),
      tensorflow::str_util::Join(starts, ", ").c_str(),
      tensorflow::str_util::Join(limits, ", ").c_str());

  const size_t rank = starts.size();
  if (limits.size() != rank) {
    return InvalidArgument("slice start and limit sizes differ: %zu vs %zu",
                           rank, limits.size());
  }
  if (strides.size() != rank) {
    return InvalidArgument("slice start and strides sizes differ: %zu vs %zu",
                           rank, strides.size());
  }
  if (rank != ShapeUtil::Rank(arg)) {
    return InvalidArgument(
        "slice index count does not match argument rank: %zu vs %lld", rank,
        ShapeUtil::Rank(arg));
  }

  std::vector<int64> sizes;
  sizes.reserve(rank);
  for (int64 dim = 0; dim < rank; ++dim) {
    const int64 start = starts[dim];
    const int64 limit = limits[dim];
    const int64 stride = strides[dim];

    if (start < 0) {
      return InvalidArgument("negative start index to slice: %lld", start);
    }
    const int64 dim_size = arg.dimensions(dim);
    if (limit > dim_size) {
      return InvalidArgument(
          "limit index (%lld) must be less than or equal to dimension "
          "size (%lld)",
          limit, dim_size);
    }
    VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dim, start);
    VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dim, limit);
    if (start > limit) {
      return InvalidArgument(
          "limit index (%lld) must be greater or equal to "
          "start index (%lld) in slice with positive stride",
          limit, start);
    }
    if (stride <= 0) {
      return InvalidArgument("stride (%lld) must be positive", stride);
    }

    // Ceiling division of the covered extent by the stride.
    const int64 extent = limit - start;
    sizes.push_back((extent + stride - 1) / stride);
  }

  return ShapeUtil::MakeShape(arg.element_type(), sizes);
}

// Infers the shape of a dynamic slice: a `slice_sizes`-shaped region of
// `operand_shape` whose start position is supplied at runtime. The start
// indices arrive in a rank-1 integral array with one entry per operand
// dimension. The result has the operand's element type and dimensions
// `slice_sizes`, where each size must lie in [0, operand dimension].
/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
    const Shape& operand_shape, const Shape& start_indices_shape,
    tensorflow::gtl::ArraySlice<int64> slice_sizes) {
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic slice"));
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(start_indices_shape,
                                            "start indices of dynamic slice"));

  VLOG(2) << tensorflow::strings::Printf(
      "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
      ShapeUtil::HumanString(operand_shape).c_str(),
      ShapeUtil::HumanString(start_indices_shape).c_str(),
      tensorflow::str_util::Join(slice_sizes, ", ").c_str());

  // The start indices must form a rank-1 integral vector.
  const int64 indices_rank = ShapeUtil::Rank(start_indices_shape);
  if (indices_rank != 1) {
    return InvalidArgument(
        "dynamic slice start indices of rank %lld must be rank1.",
        indices_rank);
  }
  if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
    return InvalidArgument(
        "dynamic slice start indices must be of integral type.");
  }

  // One start index and one slice size per operand dimension.
  const int64 operand_rank = ShapeUtil::Rank(operand_shape);
  const int64 start_num_dims = start_indices_shape.dimensions(0);
  if (operand_rank != start_num_dims) {
    return InvalidArgument(
        "dynamic slice start number of dimensions %lld (%s) must match rank "
        "%lld of slice input (%s)",
        start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(),
        operand_rank, ShapeUtil::HumanString(operand_shape).c_str());
  }
  if (slice_sizes.size() != operand_rank) {
    return InvalidArgument(
        "dynamic slice index count does not match argument rank: %zu vs %lld",
        slice_sizes.size(), operand_rank);
  }

  for (int64 dim = 0; dim < slice_sizes.size(); ++dim) {
    const int64 requested = slice_sizes[dim];
    const int64 available = operand_shape.dimensions(dim);
    if (requested < 0) {
      return InvalidArgument("negative size index to dynamic slice: %lld",
                             requested);
    }
    if (requested > available) {
      return InvalidArgument(
          "slice dim size %lld greater than dynamic slice dimension: %lld",
          requested, available);
    }
    VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim,
                                           requested);
  }

  return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
}

// Infers the result shape of a dynamic-update-slice. This operation writes
// `update_shape`-shaped data into `operand_shape` at runtime-provided start
// indices. The result has the operand's shape.
//
// Requirements, each reported as InvalidArgument:
//  * `start_indices_shape` is a rank-1 integral array with one entry per
//    operand dimension;
//  * `update_shape` has the operand's rank and element type;
//  * every update dimension lies in [0, corresponding operand dimension].
/* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
    const Shape& operand_shape, const Shape& update_shape,
    const Shape& start_indices_shape) {
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic update slice"));
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(update_shape, "update of dynamic update slice"));
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
      start_indices_shape, "start indices of dynamic update slice"));

  VLOG(2) << tensorflow::strings::Printf(
      "updating slice of shape %s at dynamic start_indices %s with update "
      "shape %s",
      ShapeUtil::HumanString(operand_shape).c_str(),
      ShapeUtil::HumanString(start_indices_shape).c_str(),
      ShapeUtil::HumanString(update_shape).c_str());

  if (ShapeUtil::Rank(start_indices_shape) != 1) {
    return InvalidArgument(
        "dynamic update slice start indices of rank %lld must be rank1.",
        ShapeUtil::Rank(start_indices_shape));
  }

  if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
    return InvalidArgument(
        "dynamic update slice start indices must be of integral type.");
  }

  const int64 start_num_dims = start_indices_shape.dimensions(0);
  if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
    // Name the actual op. This message was copy-pasted from the
    // dynamic-slice path and said "dynamic slice".
    return InvalidArgument(
        "dynamic update slice start number of dimensions %lld (%s) must match "
        "rank %lld of slice input (%s)",
        start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(),
        ShapeUtil::Rank(operand_shape),
        ShapeUtil::HumanString(operand_shape).c_str());
  }

  if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) {
    return InvalidArgument(
        "dynamic update slice update rank does not match argument rank: "
        "%lld vs %lld",
        ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape));
  }

  if (operand_shape.element_type() != update_shape.element_type()) {
    return InvalidArgument(
        "dynamic update slice update element type does not match argument. "
        "operand.element_type: %s vs update.element_type: %s",
        PrimitiveType_Name(operand_shape.element_type()).c_str(),
        PrimitiveType_Name(update_shape.element_type()).c_str());
  }

  for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) {
    const int64 input_dim_size = operand_shape.dimensions(dim);
    const int64 update_dim_size = update_shape.dimensions(dim);
    if (update_dim_size < 0) {
      return InvalidArgument(
          "size index %lld to dynamic update slice must be >= 0",
          update_dim_size);
    }
    if (update_dim_size > input_dim_size) {
      return InvalidArgument(
          "update dim size %lld greater than dynamic slice dimension: %lld",
          update_dim_size, input_dim_size);
    }
    VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim,
                                           update_dim_size);
  }

  return operand_shape;
}

// Reversing along some dimensions never changes the shape, so the result is
// `operand_shape`. The listed `dimensions` must be unique, and each must lie
// within [0, rank).
/*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
    const Shape& operand_shape, tensorflow::gtl::ArraySlice<int64> dimensions) {
  TF_RETURN_IF_ERROR(
      ExpectNotTupleOrOpaque(operand_shape, "operand of reverse"));
  if (!AllUnique(dimensions)) {
    return InvalidArgument("a dimension number is duplicated in reverse");
  }
  for (const int64 dim : dimensions) {
    const bool out_of_bounds =
        dim < 0 || dim >= ShapeUtil::Rank(operand_shape);
    if (out_of_bounds) {
      return InvalidArgument(
          "one of the reverse dimensions (%lld) is out-of-bounds in shape %s",
          dim, ShapeUtil::HumanString(operand_shape).c_str());
    }
  }
  return operand_shape;
}

// Returns the shape of element `index` of the tuple shape `arg`.
// Returns InvalidArgument if `arg` is not a tuple, or if `index` lies outside
// [0, arg.tuple_shapes_size()).
/* static */ StatusOr<Shape> ShapeInference::InferGetTupleElementShape(
    const Shape& arg, int64 index) {
  if (!ShapeUtil::IsTuple(arg)) {
    return InvalidArgument(
        "cannot infer shape: attempting to index into non-tuple: %s",
        ShapeUtil::HumanString(arg).c_str());
  }

  // A negative index passes the upper-bound check below. It must be rejected
  // here, before it reaches arg.tuple_shapes(index).
  if (index < 0) {
    return InvalidArgument(
        "cannot infer shape: attempt to index out of tuple bounds: %lld "
        "< 0 in shape %s",
        index, ShapeUtil::HumanString(arg).c_str());
  }

  if (index >= arg.tuple_shapes_size()) {
    return InvalidArgument(
        "cannot infer shape: attempt to index out of tuple bounds: %lld "
        ">= %d in shape %s",
        index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str());
  }

  return arg.tuple_shapes(index);
}

// Infers the shape of a while loop, which is the shape of `init`.
//
// `condition` and `body` must each take exactly one parameter, and
// `condition` must return a PRED scalar. The body result must be compatible
// with the condition parameter, the body parameter and `init`.
/* static */ StatusOr<Shape> ShapeInference::InferWhileShape(
    const ProgramShape& condition, const ProgramShape& body,
    const Shape& init) {
  // Arity checks come first; the parameters are indexed below.
  if (condition.parameters_size() != 1) {
    return InvalidArgument("condition must take 1 arguments; got %d",
                           condition.parameters_size());
  }
  if (body.parameters_size() != 1) {
    return InvalidArgument("body must take 1 arguments; got %d",
                           body.parameters_size());
  }

  // Formats all three shapes for diagnostics; built only on error paths.
  const auto describe_shapes = [&condition, &body, &init]() {
    return tensorflow::strings::Printf(
        "condition: %s; body: %s; init: %s",
        ShapeUtil::HumanString(condition).c_str(),
        ShapeUtil::HumanString(body).c_str(),
        ShapeUtil::HumanString(init).c_str());
  };

  if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) {
    return InvalidArgument("condition must return a boolean; got %s",
                           describe_shapes().c_str());
  }

  const Shape& loop_state = body.result();
  const bool shapes_agree =
      ShapeUtil::Compatible(loop_state, condition.parameters(0)) &&
      ShapeUtil::Compatible(loop_state, body.parameters(0)) &&
      ShapeUtil::Compatible(loop_state, init);
  if (!shapes_agree) {
    return InvalidArgument(
        "the parameter of condition and body, the result of the body, and init "
        "must all have the same shape; got %s",
        describe_shapes().c_str());
  }

  return init;
}

// Infers the shape of broadcasting `operand` by prepending `broadcast_sizes`
// as new leading dimensions. The result dimensions are broadcast_sizes
// followed by the operand's dimensions. Every broadcast size must be
// non-negative.
/* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
    const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast"));

  const auto first_negative =
      std::find_if(broadcast_sizes.begin(), broadcast_sizes.end(),
                   [](int64 size) { return size < 0; });
  if (first_negative != broadcast_sizes.end()) {
    return InvalidArgument("Broadcast with negative dimension size %lld.",
                           *first_negative);
  }

  std::vector<int64> dimensions(broadcast_sizes.begin(),
                                broadcast_sizes.end());
  dimensions.insert(dimensions.end(), operand.dimensions().begin(),
                    operand.dimensions().end());
  return ShapeUtil::MakeShape(operand.element_type(), dimensions);
}

// Infers the shape of reshaping `operand` to `new_sizes`. The result keeps
// the operand's element type.
//
// Returns InvalidArgument in these cases:
//  * any entry of `new_sizes` is negative;
//  * the element counts of the operand and the new shape differ;
//  * `dimensions` is not a permutation of [0, rank(operand)).
/* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
    const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
    tensorflow::gtl::ArraySlice<int64> new_sizes) {
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "reshape"));

  // Reject negative sizes before building the shape. Otherwise a pair such as
  // {-2, -3} multiplies to a positive element count and can slip past the
  // element-count comparison below.
  for (int64 size : new_sizes) {
    if (size < 0) {
      return InvalidArgument("Reshape with negative dimension size %lld.",
                             size);
    }
  }

  Shape inferred_shape =
      ShapeUtil::MakeShape(operand.element_type(), new_sizes);
  VLOG(3) << "Reshape inferred shape: "
          << ShapeUtil::HumanString(inferred_shape);

  if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
    return InvalidArgument(
        "reshape operation has mismatched element counts: from=%lld (%s) "
        "to=%lld (%s)",
        ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(),
        ShapeUtil::ElementsIn(inferred_shape),
        ShapeUtil::HumanString(inferred_shape).c_str());
  }

  std::vector<int64> indices(ShapeUtil::Rank(operand));
  std::iota(indices.begin(), indices.end(), 0);
  if (dimensions.size() != ShapeUtil::Rank(operand) ||
      !std::is_permutation(dimensions.begin(), dimensions.end(),
                           indices.begin())) {
    return InvalidArgument(
        "Reshape dimensions [%s] are not a permutation of the operand "
        "dimensions (operand shape is %s).",
        tensorflow::str_util::Join(dimensions, ",").c_str(),
        ShapeUtil::HumanString(operand).c_str());
  }

  return inferred_shape;
}

// Infers the shape of transposing `operand` by `dimensions`, which must be a
// permutation of [0, rank(operand)). Output dimension i takes the size of
// operand dimension dimensions[i].
/* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
    const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "transpose"));

  const int64 rank = ShapeUtil::Rank(operand);
  bool valid_permutation = dimensions.size() == rank;
  if (valid_permutation) {
    std::vector<int64> identity(rank);
    std::iota(identity.begin(), identity.end(), 0);
    valid_permutation = std::is_permutation(
        dimensions.begin(), dimensions.end(), identity.begin());
  }
  if (!valid_permutation) {
    return InvalidArgument(
        "Transpose dimensions not a permutation of the operand dimensions.");
  }

  // PermuteDimensions(p, s) places s's dimension i at position p[i]. We want
  // output[i] = input[dimensions[i]], so pass the inverse permutation.
  return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
}

// TODO(b/36794510): Make broadcast semantics more consistent, by supporting
// "degenerate" cases, as with binary elementwise ops.
//
// Infers the shape of clamp(min, operand, max). All three must be arrays of
// the same element type. Supported shape combinations:
//  * min and max are each either compatible with operand or scalar -> operand;
//  * operand is scalar and min/max are compatible with each other -> min;
//  * operand is scalar and min is scalar -> max;
//  * operand is scalar and max is scalar -> min.
// Anything else is reported as Unimplemented.
/* static */ StatusOr<Shape> ShapeInference::InferClampShape(
    const Shape& min, const Shape& operand, const Shape& max) {
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));

  const bool element_types_match = ShapeUtil::SameElementType(min, operand) &&
                                   ShapeUtil::SameElementType(max, operand);
  if (!element_types_match) {
    return InvalidArgument("clamp op with different operand types: %s, %s, %s",
                           ShapeUtil::HumanString(min).c_str(),
                           ShapeUtil::HumanString(operand).c_str(),
                           ShapeUtil::HumanString(max).c_str());
  }

  // Common case: bounds either match the operand or broadcast as scalars.
  const bool min_fits =
      ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min);
  const bool max_fits =
      ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max);
  if (min_fits && max_fits) {
    return operand;
  }

  // Scalar operand: the result takes the shape of the non-scalar bound.
  if (ShapeUtil::IsScalar(operand)) {
    if (ShapeUtil::Compatible(min, max)) return min;
    if (ShapeUtil::IsScalar(min)) return max;
    if (ShapeUtil::IsScalar(max)) return min;
  }

  return Unimplemented(
      "not yet implemented: %s, %s <clamp> %s", min.ShortDebugString().c_str(),
      max.ShortDebugString().c_str(), operand.ShortDebugString().c_str());
}

// TODO(b/36794510): Make broadcast semantics more consistent, by supporting
// "degenerate" cases, as with binary elementwise ops, as well as scalar
// broadcast from all operands, not just the predicate.
//
// Infers the shape of select(pred, on_true, on_false), which is the shape of
// `on_true`. `on_true` and `on_false` must be compatible. `pred` must have
// PRED element type and be either a scalar or have on_true's dimensions.
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
    const Shape& pred, const Shape& on_true, const Shape& on_false) {
  if (!ShapeUtil::Compatible(on_true, on_false)) {
    return InvalidArgument(
        "operands to select must be the same shape; got %s and %s",
        ShapeUtil::HumanString(on_true).c_str(),
        ShapeUtil::HumanString(on_false).c_str());
  }
  if (pred.element_type() != PRED) {
    return InvalidArgument(
        "select's pred operand must have PRED element type; got %s",
        ShapeUtil::HumanString(pred).c_str());
  }

  // pred is known to be PRED-typed here; it must either match the operands'
  // dimensions or be a scalar that applies to every element.
  const bool pred_shape_ok =
      ShapeUtil::SameDimensions(pred, on_true) || ShapeUtil::Rank(pred) == 0;
  if (!pred_shape_ok) {
    return Unimplemented(
        "select operation with non-scalar predicate with dimensionality "
        " different from the other operands: %s",
        ShapeUtil::HumanString(pred).c_str());
  }
  return on_true;
}

// Infers the shape of calling `to_apply` with arguments of shapes
// `arg_shapes`, which is the computation's result shape. The number of
// arguments must equal the computation's arity. Each argument shape must be
// compatible with the corresponding parameter shape.
/* static */ StatusOr<Shape> ShapeInference::InferCallShape(
    tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
    const ProgramShape& to_apply) {
  const int arity = to_apply.parameters_size();
  if (arg_shapes.size() != arity) {
    const string signature = ShapeUtil::HumanString(to_apply);
    const string joined_args = tensorflow::str_util::Join(
        arg_shapes, ", ", [](string* out, const Shape* shape) {
          tensorflow::strings::StrAppend(out, ShapeUtil::HumanString(*shape));
        });
    return InvalidArgument(
        "Call applied function arity must match number of arguments; got: "
        "arity: %d, arguments: %zu; computation signature: %s; argument "
        "shapes: [%s]",
        arity, arg_shapes.size(), signature.c_str(), joined_args.c_str());
  }

  for (int param_index = 0; param_index < arg_shapes.size(); ++param_index) {
    const Shape& param = to_apply.parameters(param_index);
    const Shape& arg = *arg_shapes[param_index];
    if (ShapeUtil::Compatible(arg, param)) {
      continue;
    }
    return InvalidArgument(
        "Call parameter must match argument; got parameter %d shape: %s, "
        "argument shape: %s",
        param_index, ShapeUtil::HumanString(param).c_str(),
        ShapeUtil::HumanString(arg).c_str());
  }

  return to_apply.result();
}

}  // namespace xla