aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib/gtl/array_slice_internal.h
blob: 080f0a38d8842707d9be6c7d7f8926d129a972da (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
// NOT FOR INCLUSION BY CLIENT CODE. This file is only to be included by
// array_slice.h.

// Helper functions and templates for ArraySlice.

#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_

#include <stddef.h>
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace gtl {
namespace array_slice_internal {

// Template logic for generic constructors.

// Wrappers whose Get() delegates to the appropriate method of a container, and
// is defined when this method exists. Delegates to the const method if C is a
// const type.
// Adapter for C::data(). Get() takes part in overload resolution only when
// `v->data()` is a valid expression, and it returns exactly what that
// expression returns, so a const C selects the const overload of data().
struct Data {
  // The return type is spelled with the same lvalue expression that the body
  // evaluates. Detection and the real call therefore cannot disagree, e.g.
  // for a data() that is ref-qualified with `&`.
  template <typename C>
  static auto Get(C* v) -> decltype(v->data()) {
    return v->data();
  }
};

// Adapter for C::mutable_data(). Get() takes part in overload resolution only
// when `v->mutable_data()` is a valid expression, and it returns exactly what
// that expression returns.
struct MutableData {
  // The return type is spelled with the same lvalue expression that the body
  // evaluates. Detection and the real call therefore cannot disagree, e.g.
  // for a mutable_data() that is ref-qualified with `&`.
  template <typename C>
  static auto Get(C* v) -> decltype(v->mutable_data()) {
    return v->mutable_data();
  }
};

// Adapter for C::size(). Get() takes part in overload resolution only when
// `v->size()` is a valid expression, and it returns exactly what that
// expression returns.
struct Size {
  // The return type is spelled with the same lvalue expression that the body
  // evaluates. Detection and the real call therefore cannot disagree, e.g.
  // for a size() that is ref-qualified with `&`.
  template <typename C>
  static auto Get(C* v) -> decltype(v->size()) {
    return v->size();
  }
};

// Gives mutable access to the characters of a string. Get() is not a
// template: it only accepts a string*.
struct MutableStringData {
  // Returns a pointer to the first character of *v, or nullptr when *v is
  // empty.
  static char* Get(string* v) {
    if (v->empty()) return nullptr;
    return &*v->begin();
  }
};

// Detects whether M::Get(C*) is defined and returns a type R for which
// Checker::valid<R>() is true.
template <typename M, typename Checker, typename C>
struct HasGetHelper : public M {
 private:
  // Return type of the catch-all overload; seeing it means M::Get was not
  // viable for a C* argument.
  struct NoMatch {};
  // Both overloads take part in the same overload resolution. M::Get wins
  // whenever it is viable; the variadic Get(...) is chosen otherwise.
  using M::Get;
  static NoMatch Get(...);

 public:
  static constexpr bool HasGet() {
    using GetResult = decltype(Get(std::declval<C*>()));
    return !std::is_same<GetResult, NoMatch>::value &&
           Checker::template valid<GetResult>();
  }
};

// Defines HasGet() for a particular method, container, and checker. When
// HasGet() is true, a Get() that delegates to the method is provided as well.
//
// This primary template is the "method not available" case: HasGet() is false
// and there is no Get(). The last template parameter is computed by
// HasGetHelper and decides whether this template or the `true` specialization
// is used.
template <typename M, typename Checker, typename C,
          bool kHasGet = HasGetHelper<M, Checker, C>::HasGet()>
struct Wrapper {
  static constexpr bool HasGet() { return false; }
};

// Specialization for the case where the last template argument is true:
// HasGet() reports true and Get() forwards its argument to M::Get, returning
// whatever M::Get returns for a C*.
template <typename M, typename Checker, typename C>
struct Wrapper<M, Checker, C, true> {
  static constexpr bool HasGet() { return true; }
  static auto Get(C* v) -> decltype(M::Get(std::declval<C*>())) {
    return M::Get(v);
  }
};

// Checker that accepts a method only if its return type R is an integral
// type.
struct SizeChecker {
  template <typename R>
  static constexpr bool valid() {
    using IsIntegral = std::is_integral<R>;
    return IsIntegral::value;
  }
};

// Checker that accepts a method only if it returns a pointer to T, or a
// pointer to a less const version of T.
template <typename T>
struct DataChecker {
  // The goal is to allow conversion from std::vector<T*> to
  // ArraySlice<const T*> while rejecting conversion from std::vector<Derived>
  // to ArraySlice<Base>. This relies on the fact that U** is convertible to
  // Q* const* if and only if Q is the same type as U or a more cv-qualified
  // version of U.
  template <typename R>
  static constexpr bool valid() {
    using Returned = R*;
    using Expected = T* const*;
    return std::is_convertible<Returned, Expected>::value;
  }
};

// Names Preferred when Preferred::HasGet() is true, and Fallback otherwise.
template <typename Preferred, typename Fallback>
using FirstWithGet =
    typename std::conditional<Preferred::HasGet(), Preferred, Fallback>::type;

// Wraps data() as invoked on a const Container. The accessor is accepted when
// DataChecker<const Elem> validates its return type, i.e. when it yields a
// pointer usable as a pointer to const Elem.
template <typename Elem, typename Container>
using ContainerData =
    Wrapper<Data, DataChecker<const Elem>, const Container>;

// Wraps a method returning a pointer to mutable data. Candidates are tried in
// this order: data(), then mutable_data(), then the string accessor (which
// handles strings when Elem==char). When data() returns a pointer to mutable
// data it is most likely overloaded, but it may also be a single method
// 'Elem* Container::data() const' in a non-STL-compliant container.
template <typename Elem, typename Container>
using ContainerMutableData = FirstWithGet<
    Wrapper<Data, DataChecker<Elem>, Container>,
    FirstWithGet<Wrapper<MutableData, DataChecker<Elem>, Container>,
                 Wrapper<MutableStringData, DataChecker<Elem>, Container>>>;

// Wraps size() as invoked on a const Container. The accessor is accepted when
// SizeChecker validates its return type.
template <typename Container>
using ContainerSize = Wrapper<Size, SizeChecker, const Container>;

// Implementation class for ArraySlice and MutableArraySlice: a (pointer,
// length) pair describing a contiguous run of T. The class never allocates or
// frees the storage it points at. In the case of ArraySlice, T will be a const
// type; for MutableArraySlice, T will be a mutable type.
template <typename T>
class ArraySliceImplBase {
 public:
  using pointer = T*;
  using const_pointer = const T*;
  using reference = T&;
  using const_reference = const T&;
  using iterator = pointer;
  using const_iterator = const_pointer;
  using reverse_iterator = std::reverse_iterator<iterator>;
  using const_reverse_iterator = std::reverse_iterator<const_iterator>;
  using size_type = size_t;
  using difference_type = ptrdiff_t;

  // Largest value representable by size_type.
  static const size_type npos = -1;

  // Views `length` elements starting at `array`.
  ArraySliceImplBase(pointer array, size_type length)
      : ptr_(array), length_(length) {}

  // Substring of another ArraySlice.
  // pos must be non-negative and <= x.length().
  // len must be non-negative and will be pinned to at most x.length() - pos.
  ArraySliceImplBase(const ArraySliceImplBase& x, size_type pos, size_type len)
      : ptr_(x.ptr_ + pos), length_(std::min(x.length_ - pos, len)) {
    // size_type is unsigned, so a pos past the end would make
    // `x.length_ - pos` wrap around to a huge value and let `len` through
    // unclamped. Check the documented precondition with the same DCHECK
    // mechanism the other preconditions in this class use.
    DCHECK_GE(x.length_, pos);
  }

  // Some of the const methods below return pointers and references to mutable
  // data. This is only the case in this internal class; ArraySlice and
  // MutableArraySlice provide deep-constness.

  // Pointer to the first element and number of elements in the slice.
  pointer data() const { return ptr_; }
  size_type size() const { return length_; }

  // Resets the slice to an empty one with a null data pointer.
  void clear() {
    ptr_ = nullptr;
    length_ = 0;
  }

  // Unchecked element access.
  reference operator[](size_type i) const { return ptr_[i]; }
  // Element access; i < size() is verified with DCHECK only.
  reference at(size_type i) const {
    DCHECK_LT(i, length_);
    return ptr_[i];
  }
  // First element; the slice must not be empty (verified with DCHECK).
  reference front() const {
    DCHECK_GT(length_, 0);
    return ptr_[0];
  }
  // Last element; the slice must not be empty (verified with DCHECK).
  reference back() const {
    DCHECK_GT(length_, 0);
    return ptr_[length_ - 1];
  }

  // Drops the first n elements from the slice; n must be <= size() (verified
  // with DCHECK).
  void remove_prefix(size_type n) {
    DCHECK_GE(length_, n);
    ptr_ += n;
    length_ -= n;
  }
  // Drops the last n elements from the slice; n must be <= size() (verified
  // with DCHECK).
  void remove_suffix(size_type n) {
    DCHECK_GE(length_, n);
    length_ -= n;
  }

  // Iteration over [data(), data() + size()), forwards and in reverse.
  iterator begin() const { return ptr_; }
  iterator end() const { return ptr_ + length_; }
  reverse_iterator rbegin() const { return reverse_iterator(end()); }
  reverse_iterator rend() const { return reverse_iterator(begin()); }

  // Two slices are equal when they have the same size and equal elements.
  // Slices with the same size and the same data pointer compare equal without
  // looking at the elements.
  bool operator==(const ArraySliceImplBase& other) const {
    if (size() != other.size()) return false;
    if (data() == other.data()) return true;
    return std::equal(data(), data() + size(), other.data());
  }
  bool operator!=(const ArraySliceImplBase& other) const {
    return !(*this == other);
  }

 private:
  pointer ptr_;
  size_type length_;
};

// Read-only flavor of the slice implementation: the base class is instantiated
// with const T, and all base-class constructors are inherited.
template <typename T>
class ArraySliceImpl : public ArraySliceImplBase<const T> {
 public:
  using ArraySliceImplBase<const T>::ArraySliceImplBase;

  // Well-formed (and equal to void) iff both the data and the size accessors
  // are available for Container; otherwise substitution fails.
  // NOTE(review): this alias is not applied to the constructor below;
  // presumably the public wrapper in array_slice.h uses it -- confirm there.
  template <typename Container>
  using EnableIfConvertibleFrom = typename std::enable_if<
      ContainerData<T, Container>::HasGet() &&
      ContainerSize<Container>::HasGet()>::type;

  // Builds a slice over the elements of container `v`, using its data and
  // size accessors. std::addressof is used instead of `&` so that containers
  // with an overloaded operator& are handled too.
  template <typename Container>
  explicit ArraySliceImpl(const Container& v)
      : ArraySliceImplBase<const T>(
            ContainerData<T, Container>::Get(std::addressof(v)),
            ContainerSize<Container>::Get(std::addressof(v))) {}
};

// Mutable flavor of the slice implementation: the base class is instantiated
// with T itself, and all base-class constructors are inherited.
template <typename T>
class MutableArraySliceImpl : public ArraySliceImplBase<T> {
 public:
  using ArraySliceImplBase<T>::ArraySliceImplBase;

  // Well-formed (and equal to void) iff both a mutable-data accessor and the
  // size accessor are available for Container; otherwise substitution fails.
  // NOTE(review): this alias is not applied to the constructor below;
  // presumably the public wrapper in array_slice.h uses it -- confirm there.
  template <typename Container>
  using EnableIfConvertibleFrom = typename std::enable_if<
      ContainerMutableData<T, Container>::HasGet() &&
      ContainerSize<Container>::HasGet()>::type;

  // Builds a slice over the elements of the container pointed to by `v`,
  // using its mutable-data and size accessors.
  template <typename Container>
  explicit MutableArraySliceImpl(Container* v)
      : ArraySliceImplBase<T>(ContainerMutableData<T, Container>::Get(v),
                              ContainerSize<Container>::Get(v)) {}
};

}  // namespace array_slice_internal
}  // namespace gtl
}  // namespace tensorflow

#endif  // TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_