aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/shape_util.h
blob: 73f541d50512523b0c5ddd76a9c0427c39c0824f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Shapes are protobuf messages, so this utility header offers a bunch of
// functionality for querying / poking at them.

#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_

#include <initializer_list>
#include <string>

#include "absl/base/macros.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// An index for specifying a particular nested subshape within a shape. Used in
// ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data
// structures (trees) and ShapeIndex defines a path through the tree where each
// element of ShapeIndex indexes into a tuple (or nested tuple) within the
// shape. For a non-nested tuple, an index has a single element. For example,
// given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index
// {1} corresponds to array b. For a nested tuple, the index can have more than
// one element. For the nested tuple (a, (b, c, d), e) below are the values
// corresponding to the given indices:
//
//   index {0}    : array a
//   index {1, 2} : array d
//   index {2}    : array e
//   index {0, 0} : invalid index (element at {0} is an array not a tuple)
//
// For indexing into array shapes, the index is always trivially empty, ie {}.
//
// ShapeIndex is a trivial wrapper around absl::InlinedVector with a minimum
// number of methods implemented.
class ShapeIndex {
 public:
  // Storage for the index path: an absl::InlinedVector with an inline capacity
  // of two elements.
  using container_type = absl::InlinedVector<int64, 2>;

  // Constructs an empty index.
  ShapeIndex() = default;
  // Constructs an index from a braced list of elements, e.g. ShapeIndex{1, 2}.
  ShapeIndex(std::initializer_list<int64> init) : indices_(init) {}
  // Constructs an index from the elements in the iterator range [start, end).
  template <typename InputIt>
  ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {}

  // Size queries.
  size_t size() const { return indices_.size(); }
  bool empty() const { return indices_.empty(); }

  // Appends / removes an element at the end of the path.
  void push_back(int64 value) { indices_.push_back(value); }
  void pop_back() { indices_.pop_back(); }

  // Prepends `value` to the path. This shifts every existing element, so a
  // single call is O(n) in the current size; that is acceptable because shapes
  // don't usually have a ton of dimensions.
  void push_front(int64 value) { indices_.insert(indices_.begin(), value); }

  // Iteration over the elements of the path, outermost tuple index first.
  container_type::iterator begin() { return indices_.begin(); }
  container_type::iterator end() { return indices_.end(); }
  container_type::const_iterator begin() const { return indices_.cbegin(); }
  container_type::const_iterator end() const { return indices_.cend(); }

  // Pointer to the contiguous element storage.
  const int64* data() const { return indices_.data(); }

  // Access to the last element of the path.
  int64& back() { return indices_.back(); }
  int64 back() const { return indices_.back(); }

  // Unchecked element access.
  int64& operator[](size_t i) { return indices_[i]; }
  const int64& operator[](size_t i) const { return indices_[i]; }

  // Comparisons are delegated to the underlying container.
  bool operator==(const ShapeIndex& other) const {
    return indices_ == other.indices_;
  }
  bool operator!=(const ShapeIndex& other) const {
    return indices_ != other.indices_;
  }
  bool operator<(const ShapeIndex& other) const {
    return indices_ < other.indices_;
  }

  // Returns a string representation of this index (defined out of line).
  string ToString() const;

 private:
  container_type indices_;
};

// A view into a ShapeIndex as above, with the cheap/easy ability to consume the
// value at the front of the view.
//
// NB! ShapeIndexView does not own the memory backing the index array.
// The memory backing the index array should be owned by an object
// that lives longer than the ShapeIndexView instances pointing into
// it.
class ShapeIndexView {
 public:
  ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0)
      : indices_(shape_index.data() + offset, shape_index.size() - offset) {
    CHECK_LE(offset, shape_index.size());
  }
  ShapeIndexView(std::initializer_list<int64> indices) : indices_(indices) {}
  ShapeIndexView(const ShapeIndexView& other) = default;

  using iterator = const int64*;

  iterator begin() const { return indices_.begin(); }
  iterator end() const { return indices_.end(); }
  int64 size() const { return indices_.size(); }
  bool empty() const { return indices_.empty(); }
  int64 front() const {
    CHECK(!empty());
    return indices_.front();
  }
  ShapeIndexView ConsumeFront() const {
    ShapeIndexView result = *this;
    result.indices_.remove_prefix(1);
    return result;
  }
  ShapeIndexView ConsumeBack() const {
    ShapeIndexView result = *this;
    result.indices_.remove_suffix(1);
    return result;
  }
  ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); }

  bool operator==(const ShapeIndexView& other) const;
  bool operator!=(const ShapeIndexView& other) const;

  string ToString() const;

 private:
  absl::Span<const int64> indices_;
};

// Stream insertion operators for ShapeIndex and ShapeIndexView (defined out of
// line; presumably they print the same text as ToString() — confirm in the
// .cc file).
std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);

// Namespaced collection of (static) shape utilities.
//
// These are all effectively convenience functions for testing/tweaking proto
// properties, which do invariant checks before / after the operation.
class ShapeUtil {
 public:
  // Pairs a sub-shape with the ShapeIndex (its coordinates) that locates it
  // inside an enclosing, possibly tuple-shaped, shape.
  struct IndexedShape {
    // Location of `shape` within the enclosing shape.
    ShapeIndex index;
    // The sub-shape found at `index`.
    Shape shape;

    IndexedShape() = default;
    // Takes both arguments by value and moves them into the members.
    IndexedShape(ShapeIndex index, Shape shape)
        : index(std::move(index)), shape(std::move(shape)) {}
  };

  // Returns the number of elements contained within the provided shape;
  // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes
  // may not actually be able to store this number of elements. See
  // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of
  // elements that can be stored in a sparse shape.
  // Precondition: IsArray(shape)
  static int64 ElementsIn(const Shape& shape);

  // As ElementsIn(), but recurses through tuples.
  static int64 ElementsInRecursive(const Shape& shape);

  // Returns true if shape has the primitive type, recurses through tuples.
  static bool HasPrimitiveType(const Shape& shape,
                               PrimitiveType primitive_type);

  // Returns true if 'shape' is an array with zero elements.
  static bool IsZeroElementArray(const Shape& shape);

  // Returns the number of bytes required for an allocation of shape.  The
  // |pointer_size| parameter is used for calculating the size of tuple
  // shapes. This includes only the size of the top-level buffer. For example, a
  // tuple is stored as an array of pointers to other buffers. In this case,
  // this method only returns the size of the pointer array.
  static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1);

  // Returns the number of bytes used to store the primitive_type.
  //
  // NOTE(review): the precondition formerly written here,
  // "ShapeUtil::IsArray(shape)", names a `shape` argument this function does
  // not take. Presumably `primitive_type` must be an array element type —
  // confirm against the implementation.
  static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);

  // Returns the number of bytes required to store the tuple member pointers for
  // an allocation of shape. The `shape` must be a TUPLE shape, and
  // `pointer_size` must be larger than zero.
  static int64 ByteSizeOfTupleIndexTable(const Shape& shape,
                                         int64 pointer_size);

  // Returns the number of bytes required for the elements in an allocation of
  // `shape`, which must be an array shape. The return value does not include
  // the bytes needed to store sparse indices. Dense shapes use a separate
  // memory location for each element, and so for these shapes,
  // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. For dense shapes, this
  // size also includes padding if present in the layout. For sparse shapes,
  // `ByteSizeOf(shape) == ByteSizeOfElements(shape) +
  // ByteSizeOfSparseIndices(shape)`.
  static int64 ByteSizeOfElements(const Shape& shape);

  // Returns the number of bytes required for the sparse indices in an
  // allocation of shape. The shape must be an array shape. The return value
  // does not include the bytes needed to store the elements themselves (see
  // ByteSizeOfElements). NOTE(review): the previous wording said "sparse
  // indices" here, contradicting the first sentence — confirm.
  static int64 ByteSizeOfSparseIndices(const Shape& shape);

  // Returns a human-readable string that represents the given shape, with or
  // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]".
  static string HumanString(const Shape& shape);
  static string HumanStringWithLayout(const Shape& shape);

  // As above, but for program shapes, returns a string of the form:
  //
  // (param_name: f32[42x12], ...) -> f32[24x42]
  static string HumanString(const ProgramShape& program_shape);

  // Parses a ShapeUtil::HumanString-format shape string back into a shape
  // object.
  static StatusOr<Shape> ParseShapeString(absl::string_view s);

  // Returns whether the LHS and RHS shapes have the same dimensions; note: does
  // not check element type.
  // Precondition: IsArray(lhs) && IsArray(rhs)
  static bool SameDimensions(const Shape& lhs, const Shape& rhs);

  // Returns true when `lhs` and `rhs` carry the same element type. Dimensions
  // and layout are not examined.
  static bool SameElementType(const Shape& lhs, const Shape& rhs) {
    const auto lhs_type = lhs.element_type();
    const auto rhs_type = rhs.element_type();
    return lhs_type == rhs_type;
  }

  // Like SameElementType, except that any two floating point element types are
  // treated as matching regardless of their precision.
  static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
                                                 const Shape& b) {
    const bool both_floating = ElementIsFloating(a) && ElementIsFloating(b);
    return both_floating || ShapeUtil::SameElementType(a, b);
  }

  // Picks the element type to use when combining `a` and `b`:
  //  - if both have the same element type, that type is returned;
  //  - otherwise both must be floating point (CHECK-fails if not), and the
  //    type with the larger bit width is returned.
  static PrimitiveType HigherPrecisionElementType(const Shape& a,
                                                  const Shape& b) {
    const PrimitiveType a_type = a.element_type();
    if (!SameElementType(a, b)) {
      CHECK(SameElementTypeIgnoringFpPrecision(a, b));
      const PrimitiveType b_type = b.element_type();
      // Only switch to b's type when it is strictly wider than a's.
      if (primitive_util::BitWidth(a_type) <
          primitive_util::BitWidth(b_type)) {
        return b_type;
      }
    }
    return a_type;
  }

  // Returns true if the rank, dimension sizes, and element type are
  // identical. Layout is ignored. Tuple elements are compared recursively for
  // compatibility.
  static bool Compatible(const Shape& lhs, const Shape& rhs);

  // Returns true if the rank and dimension sizes are identical. Element type
  // and layout are ignored. Tuple elements are compared recursively for
  // compatibility.
  static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);

  // As Compatible, but allows one of lhs and rhs to be BF16 while the other
  // is F32. Tuple elements are compared recursively for compatibility.
  static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);

  // Returns whether the lhs and rhs shapes are identical protobufs.
  static bool Equal(const Shape& lhs, const Shape& rhs);

  // As Equal, but allows one of lhs and rhs to be F16 while the other is F32.
  static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);

  // Returns the rank (number of dimensions) of the given shape.
  // Precondition: !IsTuple(shape)
  static int64 Rank(const Shape& shape);

  // Returns the number of dimensions for which the dimension is not (trivially)
  // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just
  // fluff. Note that zero dimensions are included in the true rank, e.g.,
  // f32[3,0,1] has a true rank of 2D.
  static int64 TrueRank(const Shape& shape);

  // Creates a ProgramShape from the given parameter shapes and result shape.
  // NOTE(review): undocumented in the original; details such as parameter
  // naming should be confirmed in the .cc file.
  static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters,
                                       Shape result);

  ////////////////////
  // Scalar-specific

  // Returns whether `shape` is an array of rank 0.
  static bool IsScalar(const Shape& shape) {
    if (!IsArray(shape)) {
      return false;
    }
    return Rank(shape) == 0;
  }
  // Returns whether `shape` is an array whose TrueRank is 0.
  static bool IsEffectiveScalar(const Shape& shape) {
    if (!IsArray(shape)) {
      return false;
    }
    return TrueRank(shape) == 0;
  }

  // Returns whether "shape" is a scalar (array) with the given element_type.
  static bool IsScalarWithElementType(const Shape& shape,
                                      PrimitiveType element_type);

  // Extracts the size of the shape's dimension at dimension number
  // GetDimensionNumber(dimension_number), i.e. `dimension_number` may be
  // negative and is resolved as described below.
  static int64 GetDimension(const Shape& shape, int64 dimension_number);

  // Resolves a dimension number, supporting negative indexing.
  //
  // Negative indexing has similar semantics to Python. For an N-dimensional
  // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to
  // N-2, and so on.
  //
  // This function always returns a non-negative dimension number for any given
  // dimension_number (which itself can be negative).
  static int64 GetDimensionNumber(const Shape& shape, int64 dimension_number);

  // Returns a shape with the same dimensions as the original, but with the
  // element type changed to type.
  static Shape ChangeElementType(const Shape& original, PrimitiveType type);

  // Creates a tuple shape from a slice of element shapes within the tuple.
  static Shape MakeTupleShape(absl::Span<const Shape> shapes);

  // Creates an opaque shape. These are generally used for threading a context
  // into a custom operation.
  static Shape MakeOpaqueShape();

  // Creates a token shape. Values of this shape are used for ordering
  // side-effecting operations.
  static Shape MakeTokenShape();

  // Appends a shape to the given tuple. `tuple_shape` is modified in place.
  static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);

  // Appends a major dimension to the shape with the given bound. `shape` is
  // modified in place.
  static void AppendMajorDimension(int bound, Shape* shape);

  // Returns the nil shape: a tuple shape with no elements. Useful as a
  // sentinel Shape value.
  static Shape MakeNil() {
    return MakeTupleShape(absl::Span<const Shape>());
  }

  // Returns whether the shape is initialized, i.e. whether its element type is
  // anything other than PRIMITIVE_TYPE_INVALID.
  static bool IsInitialized(const Shape& shape) {
    const bool has_invalid_type =
        shape.element_type() == PRIMITIVE_TYPE_INVALID;
    return !has_invalid_type;
  }

  // Constructs a new shape with the given element type and sequence of
  // dimensions. `dimensions` holds one bound per dimension; an empty span
  // describes a rank-0 shape.
  static Shape MakeShape(PrimitiveType element_type,
                         absl::Span<const int64> dimensions);

  // Creates a Shape with the given dimensions whose element type is the
  // PrimitiveType that corresponds to the native type T.
  template <typename T>
  static Shape MakeShapeWithType(absl::Span<const int64> dimensions) {
    const PrimitiveType element_type =
        primitive_util::NativeToPrimitiveType<T>();
    return ShapeUtil::MakeShape(element_type, dimensions);
  }

  // Constructs a new shape with the given minor_to_major order in its Layout.
  // Returns a value shape such that shape.has_layout().
  static Shape MakeShapeWithLayout(PrimitiveType element_type,
                                   absl::Span<const int64> dimensions,
                                   absl::Span<const int64> minor_to_major);

  // Constructs a new shape with a sparse layout. NOTE(review): undocumented in
  // the original; presumably `max_sparse_elements` bounds the number of
  // elements the sparse layout can store (cf. LayoutUtil::MaxSparseElements) —
  // confirm in the .cc file.
  static Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
                                         absl::Span<const int64> dimensions,
                                         int64 max_sparse_elements);

  // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
  static Shape MakeShapeWithDescendingLayout(
      PrimitiveType element_type, absl::Span<const int64> dimensions);

  // Returns a new Shape based on the given Shape with low-dimension-major
  // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions
  // rearranged so that it has the same in-memory layout as the given shape.
  //
  // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
  //
  // NOTE(review): a minor_to_major of {n, ..., 0} makes the last dimension the
  // fastest varying, which is usually described as C (row-major) order rather
  // than Fortran order — confirm the intended wording.
  static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
      const Shape& shape);

  // As MakeShape, but the object to write to is passed in.
  static void PopulateShape(PrimitiveType element_type,
                            absl::Span<const int64> dimensions, Shape* shape);

  // Validates that the provided shape satisfies invariants.
  static Status ValidateShape(const Shape& shape);

  // Validates the provided shape satisfies invariants, except those that
  // pertain to layout.
  //
  // Layout is optional for client-provided shapes, so that the compiler may
  // determine and assign an optimized layout.
  static Status ValidateShapeWithOptionalLayout(const Shape& shape);

  // Returns whether the element type of the shape is integral (signed or
  // unsigned). Note that predicates are not considered integral here, since
  // they are logical values.
  static bool ElementIsIntegral(const Shape& shape);

  // Returns whether the element type of the shape is floating point.
  static bool ElementIsFloating(const Shape& shape);

  // Returns whether the element type of the shape is complex.
  static bool ElementIsComplex(const Shape& shape);

  // Returns whether the element type has the given bit width.
  static bool ElementHasBitWidth(const Shape& shape, int bits);

  // Returns whether the element type of the shape is integral and has
  // the specified number of bits.
  static bool ElementIsIntegralWithBits(const Shape& shape, int bits);

  // Returns whether the element type of the shape is signed. Note
  // that floating point numbers are signed.
  static bool ElementIsSigned(const Shape& shape);

  // Returns whether the shape is a tuple, i.e. its element type is TUPLE.
  static bool IsTuple(const Shape& shape) {
    const auto type = shape.element_type();
    return type == TUPLE;
  }

  // Returns whether the shape is an opaque value (i.e. an 'existential' typed
  // value that is passed to CustomCall operations); that is, whether its
  // element type is OPAQUE.
  static bool IsOpaque(const Shape& shape) {
    const auto type = shape.element_type();
    return type == OPAQUE;
  }

  // Returns whether the shape is a token value used for ordering
  // side-effecting operations; that is, whether its element type is TOKEN.
  static bool IsToken(const Shape& shape) {
    const auto type = shape.element_type();
    return type == TOKEN;
  }

  // Returns whether the shape is an array.  Note that scalars are considered
  // arrays.
  static bool IsArray(const Shape& shape);

  // Returns whether the shape is a tuple with at least one element which is
  // also a tuple.
  static bool IsNestedTuple(const Shape& shape);

  // Returns true if shape is an empty tuple.
  static bool IsEmptyTuple(const Shape& shape);

  // Returns true if shape is the nil shape (an empty tuple).
  static bool IsNil(const Shape& shape);

  // Returns the number of elements in the given tuple shape.
  // Precondition: IsTuple(shape)
  static int64 TupleElementCount(const Shape& shape);

  // Returns the tuple element shape at the given index.
  // Precondition: IsTuple(shape) && TupleElementCount(shape) > index
  static const Shape& GetTupleElementShape(const Shape& shape, int64 index);

  // Returns the number of subshapes, recursively, in the given shape.
  // NOTE(review): the original comment said "elements"; given the function
  // name this presumably counts subshapes rather than array elements —
  // confirm in the .cc file.
  static int64 SubshapeCount(const Shape& shape);

  // Slices tuple elements in the range [start, limit) and returns a new tuple
  // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
  static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit);

  // Returns the shape of the real/imaginary components of the given complex
  // shape.
  static Shape ComplexComponentShape(const Shape& complex_shape);

  // Shorthand for testing whether a shape is of a given element type and
  // sequence of dimensions. Deprecated: use Equal() instead.
  ABSL_DEPRECATED("Use Equal() instead.")
  static bool ShapeIs(const Shape& shape, PrimitiveType element_type,
                      std::initializer_list<int64> dimensions);

  // Returns true if the given shape has a subshape at the given index.
  static bool IndexIsValid(const Shape& shape, ShapeIndexView index);

  // GetSubshape and GetMutableSubshape return a particular nested Shape within
  // the given Shape argument. The non-Try variants check fail if index is
  // invalid; TryGetSubshape reports the failure through its StatusOr result
  // instead.
  static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);
  static StatusOr<const Shape*> TryGetSubshape(const Shape& shape,
                                               ShapeIndexView index);
  static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index);

  // Returns whether the given index in the given shape is a leaf element of the
  // shape.
  static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index);

  // Returns the number of leaves in the shape.
  static int64 GetLeafCount(const Shape& shape);

  // Retrieves all the leaf shapes and their indexes, in the order walked by
  // the ForEachSubshape() API.
  static std::vector<IndexedShape> GetLeafShapes(const Shape& shape);

  // Calls the given visitor function for each subshape of the given shape.
  // Subshapes are visited in DFS pre-order starting with the entire shape
  // (index {}).
  //
  // The non-mutating variant passes each subshape by const reference; the
  // mutating variant passes a pointer through which the subshape may be
  // modified.
  using VisitorFunction = std::function<void(const Shape& /*subshape*/,
                                             const ShapeIndex& /*index*/)>;
  static void ForEachSubshape(const Shape& shape, const VisitorFunction& func);
  using MutatingVisitorFunction =
      std::function<void(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
  static void ForEachMutableSubshape(Shape* shape,
                                     const MutatingVisitorFunction& func);

  // Variants of ForEach(Mutable)Subshape which propagate Status from the
  // visitor function.
  using StatusVisitorFunction = std::function<Status(
      const Shape& /*subshape*/, const ShapeIndex& /*index*/)>;
  static Status ForEachSubshapeWithStatus(const Shape& shape,
                                          const StatusVisitorFunction& func);
  using MutatingStatusVisitorFunction =
      std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
  static Status ForEachMutableSubshapeWithStatus(
      Shape* shape, const MutatingStatusVisitorFunction& func);

  // Returns true if `shape` (which must be an array) has degenerate dimensions
  // (dimensions with bound 1).
  static bool HasDegenerateDimensions(const Shape& shape);

  // Permutes the dimensions by the given permutation, so
  // return_value.dimensions[permutation[i]] = argument.dimensions[i].
  //
  // Postcondition: For any valid permutation,
  //
  //   !HasLayout(shape) ||
  //   TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
  //                      InversePermutation(permutation)).
  static Shape PermuteDimensions(absl::Span<const int64> permutation,
                                 const Shape& shape);

  // If we can go from `shape_pre` to `shape_post` by merely inserting or
  // deleting 1-sized dimensions, return the indices in `shape_pre` of the
  // deleted dimensions and the indices in `shape_post` of the inserted
  // dimensions.
  // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and
  // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s
  // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each
  // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
  // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
  // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
  //
  // NOTE(review): the leading bool of the returned tuple presumably reports
  // whether such a transformation exists — confirm in the .cc file.
  static std::tuple<bool, std::vector<int64>, std::vector<int64>>
  InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
                                    const Shape& shape_post);

  // Suppose a reshape transforms input_shape to output shape. Returns a vector
  // of pairs that indicate the input and output dimensions that this reshape
  // doesn't logically (i.e. ignoring the layout) modify. For each pair (I,O) in
  // the returned vector, the reshape transforms any input index whose I-th
  // dimension is x to an output index whose O-th dimension is x too.
  //
  // Post-condition: the returned vector is sorted (by both input and output
  // dimensions because input and output dimensions have the same order).
  //
  // Example:
  //   input  shape = T[a, b, x, y, cd]
  //   output shape = T[ab, x, 1, y, c, d]
  //   return value = {{2, 1}, {3, 3}}
  //
  //   The two pairs represent the input and output dimension of size x and
  //   those of size y.
  static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
      const Shape& input_shape, const Shape& output_shape);

  // Returns whether a transpose from input_shape to output_shape with dimension
  // mapping "dimension_mapping" produces a result which is bit-wise identical
  // to its input and thus may be replaced with a bitcast.
  //
  // Precondition: Both input_shape and output_shape have explicit layouts.
  static bool TransposeIsBitcast(const Shape& input_shape,
                                 const Shape& output_shape,
                                 absl::Span<const int64> dimension_mapping);

  // Returns whether a reshape from "input_shape" to "output_shape" is a
  // bitcast.
  //
  // Precondition: Both input_shape and output_shape have explicit layouts.
  static bool ReshapeIsBitcast(const Shape& input_shape,
                               const Shape& output_shape);

  // Find a physical layout for 'output_shape' such that
  // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns
  // true (where 'output_shape_with_layout' is 'output_shape' with the found
  // layout). The layout of 'input_shape' is kept fixed. Returns
  // 'output_shape_with_layout' if such a layout can be found, and an empty
  // optional (absl::nullopt) otherwise.
  static absl::optional<Shape> AlignLayouts(const Shape& input_shape,
                                            const Shape& output_shape);

  // Returns a shape with the given dimension deleted.
  // For example:
  // • `DeleteDimension(1, T[m, n, k]) = T[m, k]`
  static Shape DeleteDimension(int64 dim_to_delete, Shape shape);

  // Returns a shape with all the dimensions of the input shape for which `p`
  // returns true. `p` is applied to the dimension number, not to the
  // dimension's size.
  // For example:
  // • `FilterDimensions((< 2), T[m, n, k]) = T[m, n]`
  // • `FilterDimensions(is_even_number, T[m, n, k]) = T[m, k]`
  static Shape FilterDimensions(const std::function<bool(int64)>& p,
                                Shape shape);

  // Visits the indexes of `shape` in minor to major order. Iteration starts at
  // the `base` indexes and advances by the `incr` steps, up to `count`
  // (index[i] < base[i] + count[i]); visitor_function is called with the
  // current index at each step.
  //
  // visitor_function must be a callable of type StatusOr<bool>(Span<int64>) or
  // compatible. It should return true to continue iterating, or false
  // otherwise.
  template <typename FnType>
  static Status ForEachIndexWithStatus(const Shape& shape,
                                       absl::Span<const int64> base,
                                       absl::Span<const int64> count,
                                       absl::Span<const int64> incr,
                                       const FnType& visitor_function) {
    Status iteration_status =
        ForEachIndexInternal(shape, base, count, incr, visitor_function);
    return iteration_status;
  }

  // Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus: bundles
  // the three per-dimension vectors that describe an iteration space so they
  // can be passed as a single argument.
  struct IndexIterationSpace {
    // Starting index in each dimension.
    std::vector<int64> index_base;
    // Extent to iterate over in each dimension (index < base + count).
    std::vector<int64> index_count;
    // Step between visited indexes in each dimension.
    std::vector<int64> index_incr;
  };

  // Overload of ForEachIndexWithStatus that takes the base, count and incr
  // vectors bundled in an IndexIterationSpace.
  template <typename FnTy>
  static Status ForEachIndexWithStatus(
      const Shape& shape, const IndexIterationSpace& iteration_space,
      FnTy&& function) {
    const std::vector<int64>& base = iteration_space.index_base;
    const std::vector<int64>& count = iteration_space.index_count;
    const std::vector<int64>& incr = iteration_space.index_incr;
    return ShapeUtil::ForEachIndexWithStatus(shape, base, count, incr,
                                             std::forward<FnTy>(function));
  }

  template <typename FnType>
  static void ForEachIndex(const Shape& shape, absl::Span<const int64> base,
                           absl::Span<const int64> count,
                           absl::Span<const int64> incr,
                           const FnType& visitor_function) {
    ForEachIndexWithStatus(shape, base, count, incr,
                           [&](absl::Span<const int64> indices) {
                             return StatusOr<bool>(visitor_function(indices));
                           })
        .IgnoreError();
  }

  // These convenience wrappers don't take `base`, `count` and `incr`
  // explicitly, but iterate over every element in `shape` instead.

  // Visits every index of `shape`: starts at 0 in each dimension, steps by 1,
  // and covers the full extent reported by shape.dimensions().
  template <typename FnType>
  static Status ForEachIndexWithStatus(const Shape& shape,
                                       const FnType& visitor_function) {
    const int64 num_dims = shape.dimensions_size();
    // All-zero starting index and unit stride, one entry per dimension.
    const std::vector<int64> origin(num_dims, 0);
    const std::vector<int64> unit_step(num_dims, 1);
    return ForEachIndexWithStatus(shape, origin,
                                  /*count=*/AsInt64Slice(shape.dimensions()),
                                  unit_step, visitor_function);
  }

  // Whole-shape variant of ForEachIndex: visits every index of `shape` and
  // discards the Status produced by the underlying iteration.
  template <typename FnType>
  static void ForEachIndex(const Shape& shape, const FnType& visitor_function) {
    // Wrap the visitor's return value in the StatusOr<bool> that
    // ForEachIndexWithStatus expects from its callback.
    const auto status_adapter =
        [&visitor_function](absl::Span<const int64> indices) {
          return StatusOr<bool>(visitor_function(indices));
        };
    ForEachIndexWithStatus(shape, status_adapter).IgnoreError();
  }

  // A parallel version of ForEachIndex(WithStatus). This can only be used if
  // the visitor_function is thread-safe and the order of iteration does not
  // matter.
  //
  // visitor_function must be a callable of type
  // void(absl::Span<const int64>) or compatible.
  template <typename FnType>
  static void ForEachIndexParallel(const Shape& shape,
                                   absl::Span<const int64> base,
                                   absl::Span<const int64> count,
                                   absl::Span<const int64> incr,
                                   const FnType& visitor_function) {
    // The parallel version of ForEachIndexInternal can never fail.
    CHECK(ForEachIndexInternal(
              shape, base, count, incr,
              [&visitor_function](
                  absl::Span<const int64> indexes) -> StatusOr<bool> {
                visitor_function(indexes);
                return true;
              },
              /*parallel=*/true)
              .ok());
  }

  // Compute a hash for `shape`.
  static size_t Hash(const Shape& shape);

 private:
  // Validates the shape size is sane. This makes sure it's safe to do
  // calculations in int64 without overflowing.
  static Status ValidateShapeSize(const Shape& shape);

  // Validates all of the non-layout properties of the shape -- this is a helper
  // used by both the layout-optional and layout-required public methods.
  static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);

  // Shared implementation behind the ForEachIndex* entry points.
  //
  // Calls visitor_function once per index of the iteration space described by
  // base/count/incr, advancing dimensions in minor-to-major layout order.
  // Returns OK immediately, without invoking the visitor, when `shape` is a
  // zero-element array. CHECK-fails unless base, count and incr all have
  // exactly Rank(shape) entries.
  //
  // Sequential mode (parallel == false): the visitor's result is unwrapped
  // with TF_ASSIGN_OR_RETURN, so an error result is returned to the caller and
  // ends the iteration; a `false` value ends the iteration with an OK status.
  //
  // Parallel mode (parallel == true): every index is scheduled on a thread
  // pool and the visitor's return value is discarded, so the iteration cannot
  // be stopped early and this function always returns OK.
  template <typename FnType>
  static Status ForEachIndexInternal(const Shape& shape,
                                     absl::Span<const int64> base,
                                     absl::Span<const int64> count,
                                     absl::Span<const int64> incr,
                                     const FnType& visitor_function,
                                     bool parallel = false) {
    if (ShapeUtil::IsZeroElementArray(shape)) {
      return Status::OK();
    }
    CHECK_EQ(Rank(shape), base.size());
    CHECK_EQ(incr.size(), base.size());
    CHECK_EQ(count.size(), base.size());
    const int64 rank = LayoutUtil::MinorToMajor(shape).size();
    // Allows handling R0 arrays, such that the visitor function will be called
    // once with the proper empty indexes.
    int64 n = -1;
    std::vector<int64> indexes(base.begin(), base.end());
    absl::optional<tensorflow::thread::ThreadPool> pool;
    if (parallel) {
      // Only query the CPU count when a pool is actually created; the
      // sequential path never needs it.
      const int num_threads = tensorflow::port::NumSchedulableCPUs();
      pool.emplace(tensorflow::Env::Default(), "foreach", num_threads);
    }

    // `n` is the position (in minor-to-major order) of the dimension that was
    // last advanced without wrapping; it reaches `rank` once every dimension
    // has wrapped around, which ends the iteration.
    while (n < rank) {
      if (pool != absl::nullopt) {
        // `indexes` is captured by value because the loop below keeps
        // mutating it. visitor_function is captured by reference.
        // NOTE(review): presumably the pool is drained when it is destroyed
        // at function exit, before the reference can dangle -- confirm.
        pool->Schedule(
            [indexes, &visitor_function] { visitor_function(indexes); });
      } else {
        TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
        if (!should_continue) {
          break;
        }
      }
      // Increments dimensions in minor to major order.
      for (n = 0; n < rank; ++n) {
        int64 dim = LayoutUtil::Minor(shape.layout(), n);
        indexes[dim] += incr[dim];
        if (indexes[dim] < base[dim] + count[dim]) {
          break;
        }
        // This dimension overflowed: reset it and carry into the next one.
        indexes[dim] = base[dim];
      }
    }

    return Status::OK();
  }

  TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
};

std::ostream& operator<<(std::ostream& out, const Shape& shape);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_