diff options
Diffstat (limited to 'tensorflow/core/lib/gtl/array_slice_internal.h')
-rw-r--r-- | tensorflow/core/lib/gtl/array_slice_internal.h | 253 |
1 files changed, 253 insertions, 0 deletions
diff --git a/tensorflow/core/lib/gtl/array_slice_internal.h b/tensorflow/core/lib/gtl/array_slice_internal.h new file mode 100644 index 0000000000..080f0a38d8 --- /dev/null +++ b/tensorflow/core/lib/gtl/array_slice_internal.h @@ -0,0 +1,253 @@ +// NOT FOR INCLUSION BY CLIENT CODE. This file is only to be included by +// array_slice.h. + +// Helper functions and templates for ArraySlice. + +#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_ +#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_ + +#include <stddef.h> +#include <algorithm> +#include <iterator> +#include <memory> +#include <string> +#include <type_traits> +#include <utility> +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace gtl { +namespace array_slice_internal { + +// Template logic for generic constructors. + +// Wrappers whose Get() delegates to the appropriate method of a container, and +// is defined when this method exists. Delegates to the const method if C is a +// const type. +struct Data { + template <typename C> + static decltype(std::declval<C>().data()) Get(C* v) { + return v->data(); + } +}; + +struct MutableData { + template <typename C> + static decltype(std::declval<C>().mutable_data()) Get(C* v) { + return v->mutable_data(); + } +}; + +struct Size { + template <typename C> + static decltype(std::declval<C>().size()) Get(C* v) { + return v->size(); + } +}; + +struct MutableStringData { + // Defined only for string. + static char* Get(string* v) { return v->empty() ? nullptr : &*v->begin(); } +}; + +// Checks whether M::Get(C*) is defined and has a return type R such that +// Checker::valid<R>()==true. +template <typename M, typename Checker, typename C> +struct HasGetHelper : public M { + private: + struct None {}; + // M::Get is selected when it is viable. Get(...) is selected otherwise. + using M::Get; + static None Get(...); + + public: + static constexpr bool HasGet() { + using Result = decltype(Get(std::declval<C*>())); + return !std::is_same<Result, None>() && Checker::template valid<Result>(); + } +}; + +// Defines HasGet() for a particular method, container, and checker. If +// HasGet()==true, provides Get() that delegates to the method. +template <typename M, typename Checker, typename C, + bool /*has_get*/ = HasGetHelper<M, Checker, C>::HasGet()> +struct Wrapper { + static constexpr bool HasGet() { return false; } +}; + +template <typename M, typename Checker, typename C> +struct Wrapper<M, Checker, C, true> { + static constexpr bool HasGet() { return true; } + static decltype(M::Get(std::declval<C*>())) Get(C* v) { return M::Get(v); } +}; + +// Type checker for a method returning an integral value. +struct SizeChecker { + template <typename R> + static constexpr bool valid() { + return std::is_integral<R>::value; + } +}; + +// Type checker for a method returning either a pointer to T or a less const +// version of that. +template <typename T> +struct DataChecker { + // We want to enable conversion from std::vector<T*> to ArraySlice<const T*> + // but + // disable conversion from std::vector<Derived> to ArraySlice<Base>. Here we + // use + // the fact that U** is convertible to Q* const* if and only if Q is the same + // type or a more cv-qualified version of U. + template <typename R> + static constexpr bool valid() { + return std::is_convertible<R*, T* const*>::value; + } +}; + +// Aliases to A if A::HasGet()==true, or to B otherwise. +template <typename A, typename B> +using FirstWithGet = typename std::conditional<A::HasGet(), A, B>::type; + +// Wraps C::data() const, returning a pointer to const data. +template <typename T, typename C> +using ContainerData = Wrapper<Data, DataChecker<const T>, const C>; + +// Wraps a method returning a pointer to mutable data. Prefers data() over +// mutable_data(), and handles strings when T==char. If data() returns a pointer +// to mutable data, it is most likely overloaded, but may also be a single +// method 'T* C::data() const' in a non-STL-compliant container. +template <typename T, typename C> +using ContainerMutableData = + FirstWithGet<Wrapper<Data, DataChecker<T>, C>, + FirstWithGet<Wrapper<MutableData, DataChecker<T>, C>, + Wrapper<MutableStringData, DataChecker<T>, C>>>; + +// Wraps C::size() const. +template <typename C> +using ContainerSize = Wrapper<Size, SizeChecker, const C>; + +// Implementation class for ArraySlice and MutableArraySlice. In the case of +// ArraySlice, T will be a const type; for MutableArraySlice, T will be a +// mutable type. +template <typename T> +class ArraySliceImplBase { + public: + typedef T* pointer; + typedef const T* const_pointer; + typedef T& reference; + typedef const T& const_reference; + typedef pointer iterator; + typedef const_pointer const_iterator; + typedef std::reverse_iterator<iterator> reverse_iterator; + typedef std::reverse_iterator<const_iterator> const_reverse_iterator; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + + static const size_type npos = -1; + + ArraySliceImplBase(pointer array, size_type length) + : ptr_(array), length_(length) {} + + // Substring of another ArraySlice. + // pos must be non-negative and <= x.length(). + // len must be non-negative and will be pinned to at most x.length() - pos. + ArraySliceImplBase(const ArraySliceImplBase& x, size_type pos, size_type len) + : ptr_(x.ptr_ + pos), length_(std::min(x.length_ - pos, len)) {} + + // Some of the const methods below return pointers and references to mutable + // data. This is only the case in this internal class; ArraySlice and + // MutableArraySlice provide deep-constness. + + pointer data() const { return ptr_; } + size_type size() const { return length_; } + + void clear() { + ptr_ = nullptr; + length_ = 0; + } + + reference operator[](size_type i) const { return ptr_[i]; } + reference at(size_type i) const { + DCHECK_LT(i, length_); + return ptr_[i]; + } + reference front() const { + DCHECK_GT(length_, 0); + return ptr_[0]; + } + reference back() const { + DCHECK_GT(length_, 0); + return ptr_[length_ - 1]; + } + + void remove_prefix(size_type n) { + DCHECK_GE(length_, n); + ptr_ += n; + length_ -= n; + } + void remove_suffix(size_type n) { + DCHECK_GE(length_, n); + length_ -= n; + } + + iterator begin() const { return ptr_; } + iterator end() const { return ptr_ + length_; } + reverse_iterator rbegin() const { return reverse_iterator(end()); } + reverse_iterator rend() const { return reverse_iterator(begin()); } + + bool operator==(const ArraySliceImplBase& other) const { + if (size() != other.size()) return false; + if (data() == other.data()) return true; + return std::equal(data(), data() + size(), other.data()); + } + bool operator!=(const ArraySliceImplBase& other) const { + return !(*this == other); + } + + private: + pointer ptr_; + size_type length_; +}; + +template <typename T> +class ArraySliceImpl : public ArraySliceImplBase<const T> { + public: + using ArraySliceImplBase<const T>::ArraySliceImplBase; + + // Defined iff the data and size accessors for the container C have been + // defined. + template <typename C> + using EnableIfConvertibleFrom = + typename std::enable_if<ContainerData<T, C>::HasGet() && + ContainerSize<C>::HasGet()>::type; + + // Constructs from a container when EnableIfConvertibleFrom is + // defined. std::addressof handles types with overloaded operator&. + template <typename C> + explicit ArraySliceImpl(const C& v) + : ArraySliceImplBase<const T>(ContainerData<T, C>::Get(std::addressof(v)), + ContainerSize<C>::Get(std::addressof(v))) {} +}; + +template <typename T> +class MutableArraySliceImpl : public ArraySliceImplBase<T> { + public: + using ArraySliceImplBase<T>::ArraySliceImplBase; + + template <typename C> + using EnableIfConvertibleFrom = + typename std::enable_if<ContainerMutableData<T, C>::HasGet() && + ContainerSize<C>::HasGet()>::type; + + template <typename C> + explicit MutableArraySliceImpl(C* v) + : ArraySliceImplBase<T>(ContainerMutableData<T, C>::Get(v), + ContainerSize<C>::Get(v)) {} +}; + +} // namespace array_slice_internal +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_ |