diff options
Diffstat (limited to 'tensorflow/compiler/xla/shape_util.h')
-rw-r--r-- | tensorflow/compiler/xla/shape_util.h | 56 |
1 files changed, 32 insertions, 24 deletions
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 5ae04451d3..d6f17fc965 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" @@ -73,10 +74,12 @@ class ShapeIndex { // push_front is O(n^2), but shapes don't usually have a ton of dimensions. void push_front(int64 value) { indices_.insert(indices_.begin(), value); } - std::vector<int64>::const_iterator begin() const { return indices_.begin(); } - std::vector<int64>::const_iterator end() const { return indices_.end(); } - std::vector<int64>::iterator begin() { return indices_.begin(); } - std::vector<int64>::iterator end() { return indices_.end(); } + using container_type = tensorflow::gtl::InlinedVector<int64, 2>; + + container_type::const_iterator begin() const { return indices_.begin(); } + container_type::const_iterator end() const { return indices_.end(); } + container_type::iterator begin() { return indices_.begin(); } + container_type::iterator end() { return indices_.end(); } const int64* data() const { return indices_.data(); } @@ -97,7 +100,7 @@ class ShapeIndex { string ToString() const; private: - std::vector<int64> indices_; + container_type indices_; }; // A view into a ShapeIndex as above, with the cheap/easy ability to consume the @@ -110,31 +113,33 @@ class ShapeIndex { class ShapeIndexView { public: ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0) - : ShapeIndexView(shape_index.data() + offset, - shape_index.data() + shape_index.size()) { + : indices_(shape_index.data() + offset, shape_index.size() - offset) { CHECK_LE(offset, shape_index.size()); } - ShapeIndexView(std::initializer_list<int64> indices) - : ShapeIndexView(indices.begin(), indices.end()) {} + ShapeIndexView(std::initializer_list<int64> indices) : indices_(indices) {} ShapeIndexView(const ShapeIndexView& other) = default; using iterator = const int64*; - iterator begin() const { return begin_; } - iterator end() const { return end_; } - int64 size() const { return std::distance(begin_, end_); } - bool empty() const { return begin_ == end_; } + iterator begin() const { return indices_.begin(); } + iterator end() const { return indices_.end(); } + int64 size() const { return indices_.size(); } + bool empty() const { return indices_.empty(); } int64 front() const { CHECK(!empty()); - return *begin_; + return indices_.front(); } ShapeIndexView ConsumeFront() const { - CHECK(!empty()); - auto new_begin = begin_; - ++new_begin; - return ShapeIndexView(new_begin, end_); + ShapeIndexView result = *this; + result.indices_.pop_front(); + return result; + } + ShapeIndexView ConsumeBack() const { + ShapeIndexView result = *this; + result.indices_.pop_back(); + return result; } - ShapeIndex ToShapeIndex() const { return ShapeIndex(begin_, end_); } + ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); } bool operator==(const ShapeIndexView& other) const; bool operator!=(const ShapeIndexView& other) const; @@ -142,10 +147,7 @@ class ShapeIndexView { string ToString() const; private: - ShapeIndexView(iterator begin, iterator end) : begin_(begin), end_(end) {} - - iterator begin_; - iterator end_; + tensorflow::gtl::ArraySlice<int64> indices_; }; std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index); @@ -530,7 +532,13 @@ class ShapeUtil { static bool HasDegenerateDimensions(const Shape& shape); // Permutes the dimensions by the given permutation, so - // return_value.dimensions[permutation[i]] = argument.dimensions[i] + // return_value.dimensions[permutation[i]] = argument.dimensions[i]. + // + // Postcondition: For any valid permutation, + // + // !HasLayout(shape) || + // TransposeIsBitcast(shape, PermuteDimensions(permutation, shape), + // InversePermutation(permutation)). static Shape PermuteDimensions(tensorflow::gtl::ArraySlice<int64> permutation, const Shape& shape); |