Diffstat (limited to 'tensorflow/compiler/xla/shape_util.h')
 tensorflow/compiler/xla/shape_util.h | 56 ++++++++++++++++++++++++++++++++------------------------
 1 file changed, 32 insertions(+), 24 deletions(-)
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 5ae04451d3..d6f17fc965 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
@@ -73,10 +74,12 @@ class ShapeIndex {
// push_front is O(n) per call (O(n^2) to build a whole index with it), but
// shapes don't usually have a ton of dimensions.
void push_front(int64 value) { indices_.insert(indices_.begin(), value); }
- std::vector<int64>::const_iterator begin() const { return indices_.begin(); }
- std::vector<int64>::const_iterator end() const { return indices_.end(); }
- std::vector<int64>::iterator begin() { return indices_.begin(); }
- std::vector<int64>::iterator end() { return indices_.end(); }
+ using container_type = tensorflow::gtl::InlinedVector<int64, 2>;
+
+ container_type::const_iterator begin() const { return indices_.begin(); }
+ container_type::const_iterator end() const { return indices_.end(); }
+ container_type::iterator begin() { return indices_.begin(); }
+ container_type::iterator end() { return indices_.end(); }
const int64* data() const { return indices_.data(); }
@@ -97,7 +100,7 @@ class ShapeIndex {
string ToString() const;
private:
- std::vector<int64> indices_;
+ container_type indices_;
};
// A view into a ShapeIndex as above, with the cheap/easy ability to consume the
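As a usage sketch (illustrative, not part of the commit): with the new
container_type of tensorflow::gtl::InlinedVector<int64, 2>, up to two indices
are stored inline before the vector spills to the heap. Assuming the usual
InlinedVector semantics, the common one- or two-level tuple index therefore
never allocates:

    // Illustrative only; push_back comes from the surrounding header, and
    // the allocation behavior assumes InlinedVector with inline capacity 2.
    ShapeIndex index;
    index.push_back(1);   // stored inline, no heap allocation
    index.push_back(0);   // still inline (capacity 2)
    index.push_back(3);   // third element spills to the heap
    index.push_front(2);  // O(n): shifts the existing elements right
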
@@ -110,31 +113,33 @@ class ShapeIndex {
class ShapeIndexView {
public:
ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0)
- : ShapeIndexView(shape_index.data() + offset,
- shape_index.data() + shape_index.size()) {
+ : indices_(shape_index.data() + offset, shape_index.size() - offset) {
CHECK_LE(offset, shape_index.size());
}
- ShapeIndexView(std::initializer_list<int64> indices)
- : ShapeIndexView(indices.begin(), indices.end()) {}
+ ShapeIndexView(std::initializer_list<int64> indices) : indices_(indices) {}
ShapeIndexView(const ShapeIndexView& other) = default;
using iterator = const int64*;
- iterator begin() const { return begin_; }
- iterator end() const { return end_; }
- int64 size() const { return std::distance(begin_, end_); }
- bool empty() const { return begin_ == end_; }
+ iterator begin() const { return indices_.begin(); }
+ iterator end() const { return indices_.end(); }
+ int64 size() const { return indices_.size(); }
+ bool empty() const { return indices_.empty(); }
int64 front() const {
CHECK(!empty());
- return *begin_;
+ return indices_.front();
}
ShapeIndexView ConsumeFront() const {
- CHECK(!empty());
- auto new_begin = begin_;
- ++new_begin;
- return ShapeIndexView(new_begin, end_);
+ ShapeIndexView result = *this;
+ result.indices_.pop_front();
+ return result;
+ }
+ ShapeIndexView ConsumeBack() const {
+ ShapeIndexView result = *this;
+ result.indices_.pop_back();
+ return result;
}
- ShapeIndex ToShapeIndex() const { return ShapeIndex(begin_, end_); }
+ ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); }
bool operator==(const ShapeIndexView& other) const;
bool operator!=(const ShapeIndexView& other) const;
@@ -142,10 +147,7 @@ class ShapeIndexView {
string ToString() const;
private:
- ShapeIndexView(iterator begin, iterator end) : begin_(begin), end_(end) {}
-
- iterator begin_;
- iterator end_;
+ tensorflow::gtl::ArraySlice<int64> indices_;
};
std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
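To illustrate the new slice-backed view (a hedged sketch, not code from this
header): ConsumeFront() and ConsumeBack() copy the view and shrink the
underlying ArraySlice by one element, so peeling indices off either end is
O(1) and never touches the owning ShapeIndex. The helper below is
hypothetical:

    // Hypothetical helper: sums the entries of a view by repeatedly
    // consuming the front. Each step only narrows the ArraySlice; the
    // empty() check guards the pop in ConsumeFront().
    int64 SumIndices(ShapeIndexView view) {
      if (view.empty()) return 0;
      return view.front() + SumIndices(view.ConsumeFront());
    }
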
@@ -530,7 +532,13 @@ class ShapeUtil {
static bool HasDegenerateDimensions(const Shape& shape);
// Permutes the dimensions by the given permutation, so
- // return_value.dimensions[permutation[i]] = argument.dimensions[i]
+ // return_value.dimensions[permutation[i]] = argument.dimensions[i].
+ //
+ // Postcondition: For any valid permutation,
+ //
+ // !HasLayout(shape) ||
+ // TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
+ // InversePermutation(permutation)).
static Shape PermuteDimensions(tensorflow::gtl::ArraySlice<int64> permutation,
const Shape& shape);
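
A worked example of the documented rule (illustrative; the shape and
permutation here are made up, not taken from the commit):

    // With permutation = {1, 2, 0} and argument dimensions {10, 20, 30},
    // return_value.dimensions[permutation[i]] = argument.dimensions[i] gives
    //   return.dimensions[1] = 10, return.dimensions[2] = 20,
    //   return.dimensions[0] = 30   =>   dimensions {30, 10, 20}.
    Shape arg = ShapeUtil::MakeShape(F32, {10, 20, 30});
    Shape permuted = ShapeUtil::PermuteDimensions({1, 2, 0}, arg);
    // The stated postcondition then says that, whenever arg has a layout,
    // transposing permuted by InversePermutation({1, 2, 0}) == {2, 0, 1}
    // is a bitcast back to arg's element order.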